aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/winograd_input_transform.cl
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2020-10-19 12:49:44 +0100
committerGian Marco Iodice <gianmarco.iodice@arm.com>2020-10-19 15:36:57 +0000
commitbc6c374f5bc6c17c3e9b5462f5f8c3c5a5e8a13e (patch)
tree054f99e6d8e6c266e2bfd8786b8d07ee2ef60587 /src/core/CL/cl_kernels/winograd_input_transform.cl
parent7333e1f10f5da9dc67b511d326121a843771a107 (diff)
downloadComputeLibrary-bc6c374f5bc6c17c3e9b5462f5f8c3c5a5e8a13e.tar.gz
COMPMID-3740: Remove OpenCL padding: CLWinogradInputTransformKernel
- Remove padding requirement from the OpenCL kernels - Extend test to validate zero padding requirement Change-Id: I1ddf04eba783721858792efb08a2c97f11f7297e Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4206 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/winograd_input_transform.cl')
-rw-r--r--src/core/CL/cl_kernels/winograd_input_transform.cl769
1 files changed, 355 insertions, 414 deletions
diff --git a/src/core/CL/cl_kernels/winograd_input_transform.cl b/src/core/CL/cl_kernels/winograd_input_transform.cl
index 48a4e0d399..6e969bd111 100644
--- a/src/core/CL/cl_kernels/winograd_input_transform.cl
+++ b/src/core/CL/cl_kernels/winograd_input_transform.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,6 +23,50 @@
*/
#include "helpers.h"
+#define FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(datatype, basename, y_cond, z_cond) \
+ ({ \
+ basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s0) && (z_cond))); \
+ basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s1) && (z_cond))); \
+ basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s2) && (z_cond))); \
+ basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s3) && (z_cond))); \
+ basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##1).s0) && (z_cond))); \
+ basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##1).s1) && (z_cond))); \
+ })
+
+#define FILL_ZERO_OUT_OF_BOUND_6_NHWC_V(datatype, basename, y_cond, z_cond) \
+ ({ \
+ basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s0))); \
+ basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s1))); \
+ basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s2))); \
+ basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s3))); \
+ basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##1).s0))); \
+ basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##1).s1))); \
+ })
+
+#define FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(datatype, basename, y_cond, z_cond) \
+ ({ \
+ basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s0) && (z_cond))); \
+ basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s1) && (z_cond))); \
+ basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s2) && (z_cond))); \
+ basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s3) && (z_cond))); \
+ basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s4) && (z_cond))); \
+ basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s5) && (z_cond))); \
+ basename##6 = select((datatype)0, basename##6, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s6) && (z_cond))); \
+ basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s7) && (z_cond))); \
+ })
+
+#define FILL_ZERO_OUT_OF_BOUND_8_NHWC_V(datatype, basename, y_cond, z_cond) \
+ ({ \
+ basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s0))); \
+ basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s1))); \
+ basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s2))); \
+ basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s3))); \
+ basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s4))); \
+ basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s5))); \
+ basename##6 = select((datatype)0, basename##6, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s6))); \
+ basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s7))); \
+ })
+
#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact) \
({ \
comm_fact.s0 = tmp.s2 - 4.25f * tmp.s4 + tmp.s6; \
@@ -945,51 +989,54 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
uint src_stride_w,
uint dst_stride_w)
{
+ // Index channel
const int x = get_global_id(0);
+ // Index width
const int y = get_global_id(1);
#if defined(NUM_TILES_Y)
+ // Index height
const int z = get_global_id(2) % NUM_TILES_Y;
+ // Index batch size
const int b = get_global_id(2) / NUM_TILES_Y;
-#else /* defined(NUM_TILES_Y) */
+#else // defined(NUM_TILES_Y)
+ // Index height
const int z = get_global_id(2);
-#endif /* defined(NUM_TILES_Y) */
+#endif // defined(NUM_TILES_Y)
#if defined(NUM_TILES_Y)
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w;
-#else /* defined(NUM_TILES_Y) */
- __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
-#endif /* defined(NUM_TILES_Y) */
+#else // defined(NUM_TILES_Y)
+ __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
+#endif // defined(NUM_TILES_Y)
- // Clamp coordinates. This clamp is valid for all rows
+ // Origin coordinates for the width (y) and height (z) in the input tensor
int4 y_coord0 = (int4)(y * OUTPUT_TILE_W) + (int4)(0, 1, 2, 3) - (int4)PAD_LEFT;
int2 y_coord1 = (int2)(y * OUTPUT_TILE_W) + (int2)(4, 5) - (int2)PAD_LEFT;
- y_coord0 = clamp(y_coord0, (int4) - 1, (int4)SRC_DIM_1);
- y_coord1 = clamp(y_coord1, (int2) - 1, (int2)SRC_DIM_1);
+ int4 z_coord0 = (int4)(z * OUTPUT_TILE_H) + (int4)(0, 1, 2, 3) - (int4)PAD_TOP;
+ int2 z_coord1 = (int2)(z * OUTPUT_TILE_H) + (int2)(4, 5) - (int2)PAD_TOP;
- int z_coord;
- int4 valid_y0;
- int2 valid_y1;
+ // Coordinates to use to avoid out-of-bound reads
+ int4 y_coord_valid0 = clamp(y_coord0, (int4)0, (int4)((int)SRC_DIM_1 - 1));
+ int2 y_coord_valid1 = clamp(y_coord1, (int2)0, (int2)((int)SRC_DIM_1 - 1));
+ int4 z_coord_valid0 = clamp(z_coord0, (int4)0, (int4)((int)SRC_DIM_2 - 1));
+ int2 z_coord_valid1 = clamp(z_coord1, (int2)0, (int2)((int)SRC_DIM_2 - 1));
-#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
- // Row4
- z_coord = (z * 4) - (int)PAD_TOP + 4;
+ // Boundary conditions
+ int4 y_cond0 = y_coord_valid0 == y_coord0;
+ int2 y_cond1 = y_coord_valid1 == y_coord1;
+ int4 z_cond0 = z_coord_valid0 == z_coord0;
+ int2 z_cond1 = z_coord_valid1 == z_coord1;
- // If z < 0, set y to -1
- valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
- valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
- // If z >= SRC_DIM_2, set y to SRC_DIM_2
- valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
- valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
+#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
- // Clamp z coordinate
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
+ DATA_TYPE d40 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
+ DATA_TYPE d41 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
+ DATA_TYPE d42 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
+ DATA_TYPE d43 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
+ DATA_TYPE d44 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
+ DATA_TYPE d45 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
- DATA_TYPE d40 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d41 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d42 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d43 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d44 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d45 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d4, y_cond, z_cond1.s0);
DATA_TYPE k0 = d44;
DATA_TYPE k1 = d44;
@@ -1007,44 +1054,24 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
#if !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
- // Row0
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0;
-
-#if PAD_TOP != 0
- valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
- valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
- valid_y0 = select(valid_y0, (int)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
- valid_y1 = select(valid_y1, (int)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-#else // PAD_TOP != 0
- valid_y0 = y_coord0;
- valid_y1 = y_coord1;
-#endif // if PAD_TOP == 0, we cannot read out of bound
-
- DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d0, y_cond, z_cond0.s0);
+
#else // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
- int4 z_coords0 = (int4)(z * OUTPUT_TILE_H) + (int4)(0, 1, 2, 3) - (int4)PAD_TOP;
- int2 z_coords1 = (int2)(z * OUTPUT_TILE_H) + (int2)(4, 5) - (int2)PAD_TOP;
-
- valid_y0 = select((int4)y_coord0.s0, (int4) - 1, z_coords0 < (int4)0);
- valid_y1 = select((int2)y_coord0.s0, (int2) - 1, z_coords1 < (int2)0);
- valid_y0 = select(valid_y0, (int4)SRC_DIM_1, z_coords0 >= (int4)SRC_DIM_2);
- valid_y1 = select(valid_y1, (int2)SRC_DIM_1, z_coords1 >= (int2)SRC_DIM_2);
-
- z_coords0 = clamp((int4)z_coords0, (int4)0, (int4)((int)SRC_DIM_2 - 1));
- z_coords1 = clamp((int2)z_coords1, (int2)0, (int2)((int)SRC_DIM_2 - 1));
-
- DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coords0.s0 * src_stride_z);
- DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coords0.s1 * src_stride_z);
- DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coords0.s2 * src_stride_z);
- DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coords0.s3 * src_stride_z);
- DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coords1.s0 * src_stride_z);
- DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coords1.s1 * src_stride_z);
+ DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
+ DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_6_NHWC_V(DATA_TYPE, d0, y_cond0.s0, z_cond);
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
DATA_TYPE out0 = 16.0f * d00 - 20.0f * d02 + 4.0f * d04;
@@ -1055,20 +1082,14 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
DATA_TYPE out5 = 16.0f * d01 - 20.0f * d03 + 4.0f * d05;
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
- // Row2
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 2;
- valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
- valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
- valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
- valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- DATA_TYPE d20 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d21 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d22 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d23 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d24 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d25 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ DATA_TYPE d20 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ DATA_TYPE d21 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ DATA_TYPE d22 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ DATA_TYPE d23 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ DATA_TYPE d24 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ DATA_TYPE d25 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d2, y_cond, z_cond0.s2);
out0 += k0;
out1 += k1;
@@ -1113,9 +1134,9 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
// Compute destination address
#if defined(NUM_TILES_Y)
__global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w);
-#else /* defined(NUM_TILES_Y) */
- __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y);
-#endif /* defined(NUM_TILES_Y) */
+#else // defined(NUM_TILES_Y)
+ __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y);
+#endif // defined(NUM_TILES_Y)
uint dst_plane_stride = dst_stride_z / sizeof(DATA_TYPE);
@@ -1133,34 +1154,22 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
dst_addr += dst_plane_stride;
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
- // Row1
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 1;
- // Row1 can never be out of bounds
- valid_y0 = y_coord0;
- valid_y1 = y_coord1;
-
- DATA_TYPE d10 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d11 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d12 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d13 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d14 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d15 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
-
- // Row3
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 3;
- valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
- valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
- valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
- valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- DATA_TYPE d30 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d31 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d32 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d33 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d34 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d35 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ DATA_TYPE d10 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ DATA_TYPE d11 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ DATA_TYPE d12 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ DATA_TYPE d13 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ DATA_TYPE d14 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ DATA_TYPE d15 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+
+ DATA_TYPE d30 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ DATA_TYPE d31 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ DATA_TYPE d32 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ DATA_TYPE d33 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ DATA_TYPE d34 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ DATA_TYPE d35 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d1, y_cond, z_cond0.s1);
+ FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d3, y_cond, z_cond0.s3);
// Compute common parts for the channels between [6, 29]
// Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
@@ -1270,20 +1279,14 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
dst_addr += dst_plane_stride;
// Row5
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 5;
- valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
- valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
- valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
- valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- DATA_TYPE d50 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d51 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d52 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d53 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d54 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
- DATA_TYPE d55 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ DATA_TYPE d50 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
+ DATA_TYPE d51 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
+ DATA_TYPE d52 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
+ DATA_TYPE d53 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
+ DATA_TYPE d54 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
+ DATA_TYPE d55 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d5, y_cond, z_cond1.s1);
// Channels [30, 35]
out0 = 16.0f * d10 - 20.0f * d12 - 20.0f * d30 + 25.0f * d32 + 4.0f * d50 - 5.0f * d52 + d54 + 4.0f * d14 - 5.0f * d34;
@@ -1350,37 +1353,44 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
#if defined(NUM_TILES_Y)
const int z = get_global_id(2) % NUM_TILES_Y;
const int b = get_global_id(2) / NUM_TILES_Y;
-#else /* defined(NUM_TILES_Y) */
- const int z = get_global_id(2);
-#endif /* defined(NUM_TILES_Y) */
+#else // defined(NUM_TILES_Y)
+ const int z = get_global_id(2);
+#endif // defined(NUM_TILES_Y)
// Compute input address
#if defined(NUM_TILES_Y)
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w;
-#else /* defined(NUM_TILES_Y) */
- __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
-#endif /* defined(NUM_TILES_Y) */
+#else // defined(NUM_TILES_Y)
+ __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
+#endif // defined(NUM_TILES_Y)
-#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
- // Clamp coordinates. This clamp is valid for all rows
- int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
- y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1);
+ // Origin coordinates for the width (y) and height (z) in the input tensor
+ int8 y_coord0 = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
+ int8 z_coord0 = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP;
- // Row0
- // We can skip the border clamping along the z dimension as we cannot read out-of-bound in case of 5x1 kernels
- int z_coord = z * OUTPUT_TILE_H;
+ // Coordinates to use to avoid out-of-bound reads
+ int8 y_coord_valid0 = clamp(y_coord0, (int8)0, (int8)((int)SRC_DIM_1 - 1));
+ int8 z_coord_valid0 = clamp(z_coord0, (int8)0, (int8)((int)SRC_DIM_2 - 1));
+
+ // Boundary conditions
+ int8 y_cond0 = y_coord_valid0 == y_coord0;
+ int8 z_cond0 = z_coord_valid0 == z_coord0;
+
+#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
// Load the input tile
VEC_DATA_TYPE(DATA_TYPE, 8)
in_row0;
- in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord.s0 * (int)src_stride_y + z_coord * src_stride_z);
- in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord.s1 * (int)src_stride_y + z_coord * src_stride_z);
- in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord.s2 * (int)src_stride_y + z_coord * src_stride_z);
- in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord.s3 * (int)src_stride_y + z_coord * src_stride_z);
- in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord.s4 * (int)src_stride_y + z_coord * src_stride_z);
- in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord.s5 * (int)src_stride_y + z_coord * src_stride_z);
- in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord.s6 * (int)src_stride_y + z_coord * src_stride_z);
- in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord.s7 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0);
// Calculate common factors for intermediate tensor
VEC_DATA_TYPE(DATA_TYPE, 8)
@@ -1394,27 +1404,20 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
- // We can skip the border clamping along the y dimension as we cannot read out-of-bound in case of 1x5 kernels
- int y_coord = y * (int)OUTPUT_TILE_W;
-
- // Row0
- // We can skip the border clamping along the z dimension as we cannot read out-of-bound in case of 5x1 kernels
- int8 z_coord = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP;
- int8 valid_y = select((int8)y_coord, (int8) - 1, z_coord < (int8)0); // If z < 0, set y to -1
- valid_y = select(valid_y, (int8)SRC_DIM_1, z_coord >= (int8)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2
- z_coord = clamp(z_coord, (int8)0, (int8)SRC_DIM_2 - 1); // Clamp z coordinate
// Load the input tile
VEC_DATA_TYPE(DATA_TYPE, 8)
in_row0;
- in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord.s0 * src_stride_z);
- in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord.s1 * src_stride_z);
- in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord.s2 * src_stride_z);
- in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord.s3 * src_stride_z);
- in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord.s4 * src_stride_z);
- in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord.s5 * src_stride_z);
- in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord.s6 * src_stride_z);
- in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord.s7 * src_stride_z);
+ in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+ in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+ in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+ in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_V(DATA_TYPE, in_row0.s, y_cond0.s0, z_cond);
// Calculate common factors for intermediate tensor
VEC_DATA_TYPE(DATA_TYPE, 8)
@@ -1430,130 +1433,101 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
VEC_DATA_TYPE(DATA_TYPE, 8)
in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7;
- // Clamp coordinates. This clamp is valid for all rows
- int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
- y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1);
-
// Row0
- int z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0;
- int8 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); // If z < 0, set y to -1
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); // Clamp z coordinate
+ in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
- // Load the input tile
- in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
- in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
- in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
- in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
- in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
- in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
- in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
- in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0);
// Row1
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 1;
- valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- in_row1.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
- in_row1.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
- in_row1.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
- in_row1.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
- in_row1.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
- in_row1.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
- in_row1.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
- in_row1.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row1.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ in_row1.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ in_row1.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ in_row1.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ in_row1.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ in_row1.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ in_row1.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ in_row1.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row1.s, y_cond, z_cond0.s1);
// Row2
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 2;
- valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- in_row2.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
- in_row2.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
- in_row2.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
- in_row2.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
- in_row2.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
- in_row2.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
- in_row2.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
- in_row2.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row2.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ in_row2.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ in_row2.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ in_row2.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ in_row2.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ in_row2.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ in_row2.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ in_row2.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row2.s, y_cond, z_cond0.s2);
// Row3
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 3;
- valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- in_row3.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
- in_row3.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
- in_row3.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
- in_row3.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
- in_row3.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
- in_row3.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
- in_row3.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
- in_row3.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row3.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ in_row3.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ in_row3.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ in_row3.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ in_row3.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ in_row3.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ in_row3.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ in_row3.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row3.s, y_cond, z_cond0.s3);
// Row4
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 4;
- valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- in_row4.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
- in_row4.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
- in_row4.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
- in_row4.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
- in_row4.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
- in_row4.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
- in_row4.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
- in_row4.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row4.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+ in_row4.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+ in_row4.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+ in_row4.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+ in_row4.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+ in_row4.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+ in_row4.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+ in_row4.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row4.s, y_cond, z_cond0.s4);
// Row5
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 5;
- valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- in_row5.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
- in_row5.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
- in_row5.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
- in_row5.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
- in_row5.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
- in_row5.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
- in_row5.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
- in_row5.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row5.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+ in_row5.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+ in_row5.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+ in_row5.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+ in_row5.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+ in_row5.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+ in_row5.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+ in_row5.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row5.s, y_cond, z_cond0.s5);
// Row6
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 6;
- valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- in_row6.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
- in_row6.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
- in_row6.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
- in_row6.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
- in_row6.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
- in_row6.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
- in_row6.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
- in_row6.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row6.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+ in_row6.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+ in_row6.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+ in_row6.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+ in_row6.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+ in_row6.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+ in_row6.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+ in_row6.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row6.s, y_cond, z_cond0.s6);
// Row7
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 7;
- valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- in_row7.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
- in_row7.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
- in_row7.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
- in_row7.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
- in_row7.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
- in_row7.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
- in_row7.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
- in_row7.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row7.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ in_row7.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ in_row7.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ in_row7.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ in_row7.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ in_row7.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ in_row7.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ in_row7.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row7.s, y_cond, z_cond0.s7);
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact0 = in_row2 + in_row6 - (DATA_TYPE)4.25f * in_row4;
@@ -1722,29 +1696,33 @@ __kernel void winograd_input_transform_2x2_7x7_stepz1_nhwc(
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
#endif /* defined(NUM_TILES_Y) */
-#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
+ // Origin coordinates for the width (y) and height (z) in the input tensor
+ int8 y_coord0 = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
+ int8 z_coord0 = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP;
+
+ // Coordinates to use to avoid out-of-bound reads
+ int8 y_coord_valid0 = clamp(y_coord0, (int8)0, (int8)((int)SRC_DIM_1 - 1));
+ int8 z_coord_valid0 = clamp(z_coord0, (int8)0, (int8)((int)SRC_DIM_2 - 1));
- // Clamp coordinates. This clamp is valid for all rows
- int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
- y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1);
+ // Boundary conditions
+ int8 y_cond0 = y_coord_valid0 == y_coord0;
+ int8 z_cond0 = z_coord_valid0 == z_coord0;
- // Clamp coordinates. This clamp is valid for all columns
- int z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0;
- int8 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); // If z < 0, set y to -1
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
+#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
// Load the input tile
VEC_DATA_TYPE(DATA_TYPE, 8)
in_row0;
- in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
+ in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0);
VEC_DATA_TYPE(DATA_TYPE, 8)
out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
@@ -1758,27 +1736,19 @@ __kernel void winograd_input_transform_2x2_7x7_stepz1_nhwc(
OUTPUT_ROW_2x2_7x7(out0, tmp0, comm_fact0);
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
- // We can skip the border clamping along the y dimension as we cannot read out-of-bound in case of 1x5 kernels
- int y_coord = y * (int)OUTPUT_TILE_W;
-
- // Row0
- // We can skip the border clamping along the z dimension as we cannot read out-of-bound in case of 5x1 kernels
- int8 z_coord = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP;
- int8 valid_y = select((int8)y_coord, (int8) - 1, z_coord < (int8)0); // If z < 0, set y to -1
- valid_y = select(valid_y, (int8)SRC_DIM_1, z_coord >= (int8)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2
- z_coord = clamp(z_coord, (int8)0, (int8)SRC_DIM_2 - 1); // Clamp z coordinate
-
// Load the input tile
VEC_DATA_TYPE(DATA_TYPE, 8)
in_row0;
- in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord.s0 * (int)src_stride_z);
- in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord.s1 * (int)src_stride_z);
- in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord.s2 * (int)src_stride_z);
- in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord.s3 * (int)src_stride_z);
- in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord.s4 * (int)src_stride_z);
- in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord.s5 * (int)src_stride_z);
- in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord.s6 * (int)src_stride_z);
- in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord.s7 * (int)src_stride_z);
+ in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+ in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+ in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+ in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_V(DATA_TYPE, in_row0.s, y_cond0.s0, z_cond);
// Calculate common factors for intermediate tensor
VEC_DATA_TYPE(DATA_TYPE, 8)
@@ -1795,130 +1765,101 @@ __kernel void winograd_input_transform_2x2_7x7_stepz1_nhwc(
VEC_DATA_TYPE(DATA_TYPE, 8)
in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7;
- // Clamp coordinates. This clamp is valid for all rows
- int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
- y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1);
-
// Row0
- int z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0;
- int8 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); // If z < 0, set y to -1
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); // Clamp z coordinate
+ in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
- // Load the input tile
- in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0);
// Row1
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 1;
- valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- in_row1.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row1.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row1.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row1.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row1.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row1.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row1.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row1.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
+ in_row1.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ in_row1.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ in_row1.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ in_row1.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ in_row1.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ in_row1.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ in_row1.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ in_row1.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row1.s, y_cond, z_cond0.s1);
// Row2
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 2;
- valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- in_row2.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row2.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row2.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row2.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row2.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row2.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row2.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row2.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
+ in_row2.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ in_row2.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ in_row2.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ in_row2.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ in_row2.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ in_row2.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ in_row2.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ in_row2.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row2.s, y_cond, z_cond0.s2);
// Row3
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 3;
- valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- in_row3.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row3.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row3.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row3.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row3.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row3.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row3.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row3.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
+ in_row3.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ in_row3.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ in_row3.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ in_row3.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ in_row3.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ in_row3.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ in_row3.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ in_row3.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row3.s, y_cond, z_cond0.s3);
// Row4
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 4;
- valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- in_row4.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row4.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row4.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row4.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row4.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row4.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row4.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row4.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
+ in_row4.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+ in_row4.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+ in_row4.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+ in_row4.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+ in_row4.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+ in_row4.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+ in_row4.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+ in_row4.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row4.s, y_cond, z_cond0.s4);
// Row5
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 5;
- valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- in_row5.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row5.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row5.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row5.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row5.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row5.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row5.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row5.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
+ in_row5.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+ in_row5.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+ in_row5.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+ in_row5.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+ in_row5.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+ in_row5.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+ in_row5.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+ in_row5.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row5.s, y_cond, z_cond0.s5);
// Row6
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 6;
- valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- in_row6.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row6.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row6.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row6.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row6.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row6.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row6.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row6.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
+ in_row6.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+ in_row6.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+ in_row6.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+ in_row6.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+ in_row6.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+ in_row6.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+ in_row6.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+ in_row6.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row6.s, y_cond, z_cond0.s6);
// Row7
- z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 7;
- valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
- valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
- z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
-
- in_row7.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row7.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row7.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row7.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row7.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row7.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row7.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
- in_row7.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
+ in_row7.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ in_row7.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ in_row7.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ in_row7.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ in_row7.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ in_row7.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ in_row7.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ in_row7.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row7.s, y_cond, z_cond0.s7);
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact0 = (DATA_TYPE)36.0f * in_row2 - (DATA_TYPE)13.0f * in_row4 + in_row6;