From bc6c374f5bc6c17c3e9b5462f5f8c3c5a5e8a13e Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Mon, 19 Oct 2020 12:49:44 +0100 Subject: COMPMID-3740: Remove OpenCL padding: CLWinogradInputTransformKernel - Remove padding requirement from the OpenCL kernels - Extend test to validate zero padding requirement Change-Id: I1ddf04eba783721858792efb08a2c97f11f7297e Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4206 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins --- src/core/CL/cl_kernels/winograd_input_transform.cl | 769 ++++++++++----------- .../CL/kernels/CLWinogradInputTransformKernel.cpp | 7 +- tests/validation/CL/Winograd.cpp | 39 +- 3 files changed, 394 insertions(+), 421 deletions(-) diff --git a/src/core/CL/cl_kernels/winograd_input_transform.cl b/src/core/CL/cl_kernels/winograd_input_transform.cl index 48a4e0d399..6e969bd111 100644 --- a/src/core/CL/cl_kernels/winograd_input_transform.cl +++ b/src/core/CL/cl_kernels/winograd_input_transform.cl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 Arm Limited. + * Copyright (c) 2018-2020 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,6 +23,50 @@ */ #include "helpers.h" +#define FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(datatype, basename, y_cond, z_cond) \ + ({ \ + basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s0) && (z_cond))); \ + basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s1) && (z_cond))); \ + basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s2) && (z_cond))); \ + basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s3) && (z_cond))); \ + basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##1).s0) && (z_cond))); \ + basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##1).s1) && (z_cond))); \ + }) + +#define FILL_ZERO_OUT_OF_BOUND_6_NHWC_V(datatype, basename, y_cond, z_cond) \ + ({ \ + basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s0))); \ + basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s1))); \ + basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s2))); \ + basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s3))); \ + basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##1).s0))); \ + basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##1).s1))); \ + }) + +#define FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(datatype, basename, y_cond, z_cond) \ + ({ \ + basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s0) && (z_cond))); \ + basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s1) && (z_cond))); \ + basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s2) && (z_cond))); \ + basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s3) && (z_cond))); \ + basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s4) && (z_cond))); \ + basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s5) && (z_cond))); \ + basename##6 = select((datatype)0, basename##6, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s6) && (z_cond))); \ + basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s7) && (z_cond))); \ + }) + +#define FILL_ZERO_OUT_OF_BOUND_8_NHWC_V(datatype, basename, y_cond, z_cond) \ + ({ \ + basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s0))); \ + basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s1))); \ + basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s2))); \ + basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s3))); \ + basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s4))); \ + basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s5))); \ + basename##6 = select((datatype)0, basename##6, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s6))); \ + basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s7))); \ + }) + #define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact) \ ({ \ comm_fact.s0 = tmp.s2 - 4.25f * tmp.s4 + tmp.s6; \ @@ -945,51 +989,54 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( uint src_stride_w, uint dst_stride_w) { + // Index channel const int x = get_global_id(0); + // Index width const int y = get_global_id(1); #if defined(NUM_TILES_Y) + // Index height const int z = get_global_id(2) % NUM_TILES_Y; + // Index batch size const int b = get_global_id(2) / NUM_TILES_Y; -#else /* defined(NUM_TILES_Y) */ +#else // defined(NUM_TILES_Y) + // Index height const int z = get_global_id(2); -#endif /* defined(NUM_TILES_Y) */ +#endif // defined(NUM_TILES_Y) #if defined(NUM_TILES_Y) __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w; -#else /* defined(NUM_TILES_Y) */ - __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE); -#endif /* defined(NUM_TILES_Y) */ +#else // defined(NUM_TILES_Y) + __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE); +#endif // defined(NUM_TILES_Y) - // Clamp coordinates. This clamp is valid for all rows + // Origin coordinates for the width (y) and height (z) in the input tensor int4 y_coord0 = (int4)(y * OUTPUT_TILE_W) + (int4)(0, 1, 2, 3) - (int4)PAD_LEFT; int2 y_coord1 = (int2)(y * OUTPUT_TILE_W) + (int2)(4, 5) - (int2)PAD_LEFT; - y_coord0 = clamp(y_coord0, (int4) - 1, (int4)SRC_DIM_1); - y_coord1 = clamp(y_coord1, (int2) - 1, (int2)SRC_DIM_1); + int4 z_coord0 = (int4)(z * OUTPUT_TILE_H) + (int4)(0, 1, 2, 3) - (int4)PAD_TOP; + int2 z_coord1 = (int2)(z * OUTPUT_TILE_H) + (int2)(4, 5) - (int2)PAD_TOP; - int z_coord; - int4 valid_y0; - int2 valid_y1; + // Coordinates to use to avoid out-of-bound reads + int4 y_coord_valid0 = clamp(y_coord0, (int4)0, (int4)((int)SRC_DIM_1 - 1)); + int2 y_coord_valid1 = clamp(y_coord1, (int2)0, (int2)((int)SRC_DIM_1 - 1)); + int4 z_coord_valid0 = clamp(z_coord0, (int4)0, (int4)((int)SRC_DIM_2 - 1)); + int2 z_coord_valid1 = clamp(z_coord1, (int2)0, (int2)((int)SRC_DIM_2 - 1)); -#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) - // Row4 - z_coord = (z * 4) - (int)PAD_TOP + 4; + // Boundary conditions + int4 y_cond0 = y_coord_valid0 == y_coord0; + int2 y_cond1 = y_coord_valid1 == y_coord1; + int4 z_cond0 = z_coord_valid0 == z_coord0; + int2 z_cond1 = z_coord_valid1 == z_coord1; - // If z < 0, set y to -1 - valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0); - valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0); - // If z >= SRC_DIM_2, set y to SRC_DIM_2 - valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2); - valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2); +#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) - // Clamp z coordinate - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); + DATA_TYPE d40 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z); + DATA_TYPE d41 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z); + DATA_TYPE d42 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z); + DATA_TYPE d43 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z); + DATA_TYPE d44 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z); + DATA_TYPE d45 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z); - DATA_TYPE d40 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d41 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d42 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d43 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d44 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d45 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z); + FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d4, y_cond, z_cond1.s0); DATA_TYPE k0 = d44; DATA_TYPE k1 = d44; @@ -1007,44 +1054,24 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( #endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) #if !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) - // Row0 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0; - -#if PAD_TOP != 0 - valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0); - valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0); - valid_y0 = select(valid_y0, (int)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2); - valid_y1 = select(valid_y1, (int)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); -#else // PAD_TOP != 0 - valid_y0 = y_coord0; - valid_y1 = y_coord1; -#endif // if PAD_TOP == 0, we cannot read out of bound - - DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z); + DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d0, y_cond, z_cond0.s0); + #else // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) - int4 z_coords0 = (int4)(z * OUTPUT_TILE_H) + (int4)(0, 1, 2, 3) - (int4)PAD_TOP; - int2 z_coords1 = (int2)(z * OUTPUT_TILE_H) + (int2)(4, 5) - (int2)PAD_TOP; - - valid_y0 = select((int4)y_coord0.s0, (int4) - 1, z_coords0 < (int4)0); - valid_y1 = select((int2)y_coord0.s0, (int2) - 1, z_coords1 < (int2)0); - valid_y0 = select(valid_y0, (int4)SRC_DIM_1, z_coords0 >= (int4)SRC_DIM_2); - valid_y1 = select(valid_y1, (int2)SRC_DIM_1, z_coords1 >= (int2)SRC_DIM_2); - - z_coords0 = clamp((int4)z_coords0, (int4)0, (int4)((int)SRC_DIM_2 - 1)); - z_coords1 = clamp((int2)z_coords1, (int2)0, (int2)((int)SRC_DIM_2 - 1)); - - DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coords0.s0 * src_stride_z); - DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coords0.s1 * src_stride_z); - DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coords0.s2 * src_stride_z); - DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coords0.s3 * src_stride_z); - DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coords1.s0 * src_stride_z); - DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coords1.s1 * src_stride_z); + DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z); + DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_6_NHWC_V(DATA_TYPE, d0, y_cond0.s0, z_cond); #endif // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) DATA_TYPE out0 = 16.0f * d00 - 20.0f * d02 + 4.0f * d04; @@ -1055,20 +1082,14 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( DATA_TYPE out5 = 16.0f * d01 - 20.0f * d03 + 4.0f * d05; #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) - // Row2 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 2; - valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0); - valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0); - valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2); - valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - DATA_TYPE d20 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d21 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d22 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d23 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d24 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d25 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z); + DATA_TYPE d20 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + DATA_TYPE d21 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + DATA_TYPE d22 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + DATA_TYPE d23 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + DATA_TYPE d24 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + DATA_TYPE d25 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d2, y_cond, z_cond0.s2); out0 += k0; out1 += k1; @@ -1113,9 +1134,9 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( // Compute destination address #if defined(NUM_TILES_Y) __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w); -#else /* defined(NUM_TILES_Y) */ - __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y); -#endif /* defined(NUM_TILES_Y) */ +#else // defined(NUM_TILES_Y) + __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y); +#endif // defined(NUM_TILES_Y) uint dst_plane_stride = dst_stride_z / sizeof(DATA_TYPE); @@ -1133,34 +1154,22 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( dst_addr += dst_plane_stride; #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) - // Row1 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 1; - // Row1 can never be out of bounds - valid_y0 = y_coord0; - valid_y1 = y_coord1; - - DATA_TYPE d10 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d11 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d12 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d13 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d14 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d15 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z); - - // Row3 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 3; - valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0); - valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0); - valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2); - valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - DATA_TYPE d30 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d31 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d32 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d33 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d34 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d35 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z); + DATA_TYPE d10 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + DATA_TYPE d11 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + DATA_TYPE d12 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + DATA_TYPE d13 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + DATA_TYPE d14 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + DATA_TYPE d15 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + + DATA_TYPE d30 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + DATA_TYPE d31 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + DATA_TYPE d32 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + DATA_TYPE d33 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + DATA_TYPE d34 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + DATA_TYPE d35 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d1, y_cond, z_cond0.s1); + FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d3, y_cond, z_cond0.s3); // Compute common parts for the channels between [6, 29] // Channels [6, 11]: [out10, out11, out12, out13, out14, out15] @@ -1270,20 +1279,14 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( dst_addr += dst_plane_stride; // Row5 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 5; - valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0); - valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0); - valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2); - valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - DATA_TYPE d50 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d51 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d52 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d53 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d54 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z); - DATA_TYPE d55 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z); + DATA_TYPE d50 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z); + DATA_TYPE d51 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z); + DATA_TYPE d52 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z); + DATA_TYPE d53 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z); + DATA_TYPE d54 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z); + DATA_TYPE d55 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d5, y_cond, z_cond1.s1); // Channels [30, 35] out0 = 16.0f * d10 - 20.0f * d12 - 20.0f * d30 + 25.0f * d32 + 4.0f * d50 - 5.0f * d52 + d54 + 4.0f * d14 - 5.0f * d34; @@ -1350,37 +1353,44 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc( #if defined(NUM_TILES_Y) const int z = get_global_id(2) % NUM_TILES_Y; const int b = get_global_id(2) / NUM_TILES_Y; -#else /* defined(NUM_TILES_Y) */ - const int z = get_global_id(2); -#endif /* defined(NUM_TILES_Y) */ +#else // defined(NUM_TILES_Y) + const int z = get_global_id(2); +#endif // defined(NUM_TILES_Y) // Compute input address #if defined(NUM_TILES_Y) __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w; -#else /* defined(NUM_TILES_Y) */ - __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE); -#endif /* defined(NUM_TILES_Y) */ +#else // defined(NUM_TILES_Y) + __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE); +#endif // defined(NUM_TILES_Y) -#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) - // Clamp coordinates. This clamp is valid for all rows - int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT; - y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1); + // Origin coordinates for the width (y) and height (z) in the input tensor + int8 y_coord0 = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT; + int8 z_coord0 = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP; - // Row0 - // We can skip the border clamping along the z dimension as we cannot read out-of-bound in case of 5x1 kernels - int z_coord = z * OUTPUT_TILE_H; + // Coordinates to use to avoid out-of-bound reads + int8 y_coord_valid0 = clamp(y_coord0, (int8)0, (int8)((int)SRC_DIM_1 - 1)); + int8 z_coord_valid0 = clamp(z_coord0, (int8)0, (int8)((int)SRC_DIM_2 - 1)); + + // Boundary conditions + int8 y_cond0 = y_coord_valid0 == y_coord0; + int8 z_cond0 = z_coord_valid0 == z_coord0; + +#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) // Load the input tile VEC_DATA_TYPE(DATA_TYPE, 8) in_row0; - in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord.s0 * (int)src_stride_y + z_coord * src_stride_z); - in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord.s1 * (int)src_stride_y + z_coord * src_stride_z); - in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord.s2 * (int)src_stride_y + z_coord * src_stride_z); - in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord.s3 * (int)src_stride_y + z_coord * src_stride_z); - in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord.s4 * (int)src_stride_y + z_coord * src_stride_z); - in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord.s5 * (int)src_stride_y + z_coord * src_stride_z); - in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord.s6 * (int)src_stride_y + z_coord * src_stride_z); - in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord.s7 * (int)src_stride_y + z_coord * src_stride_z); + in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0); // Calculate common factors for intermediate tensor VEC_DATA_TYPE(DATA_TYPE, 8) @@ -1394,27 +1404,20 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc( OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0); #elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) - // We can skip the border clamping along the y dimension as we cannot read out-of-bound in case of 1x5 kernels - int y_coord = y * (int)OUTPUT_TILE_W; - - // Row0 - // We can skip the border clamping along the z dimension as we cannot read out-of-bound in case of 5x1 kernels - int8 z_coord = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP; - int8 valid_y = select((int8)y_coord, (int8) - 1, z_coord < (int8)0); // If z < 0, set y to -1 - valid_y = select(valid_y, (int8)SRC_DIM_1, z_coord >= (int8)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2 - z_coord = clamp(z_coord, (int8)0, (int8)SRC_DIM_2 - 1); // Clamp z coordinate // Load the input tile VEC_DATA_TYPE(DATA_TYPE, 8) in_row0; - in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord.s0 * src_stride_z); - in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord.s1 * src_stride_z); - in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord.s2 * src_stride_z); - in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord.s3 * src_stride_z); - in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord.s4 * src_stride_z); - in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord.s5 * src_stride_z); - in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord.s6 * src_stride_z); - in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord.s7 * src_stride_z); + in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_V(DATA_TYPE, in_row0.s, y_cond0.s0, z_cond); // Calculate common factors for intermediate tensor VEC_DATA_TYPE(DATA_TYPE, 8) @@ -1430,130 +1433,101 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc( VEC_DATA_TYPE(DATA_TYPE, 8) in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7; - // Clamp coordinates. This clamp is valid for all rows - int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT; - y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1); - // Row0 - int z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0; - int8 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); // If z < 0, set y to -1 - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2 - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); // Clamp z coordinate + in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); - // Load the input tile - in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z); - in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z); - in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z); - in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z); - in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z); - in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z); - in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z); - in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z); + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0); // Row1 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 1; - valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - in_row1.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z); - in_row1.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z); - in_row1.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z); - in_row1.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z); - in_row1.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z); - in_row1.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z); - in_row1.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z); - in_row1.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z); + in_row1.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + in_row1.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + in_row1.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + in_row1.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + in_row1.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + in_row1.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + in_row1.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + in_row1.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row1.s, y_cond, z_cond0.s1); // Row2 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 2; - valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - in_row2.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z); - in_row2.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z); - in_row2.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z); - in_row2.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z); - in_row2.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z); - in_row2.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z); - in_row2.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z); - in_row2.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z); + in_row2.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + in_row2.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + in_row2.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + in_row2.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + in_row2.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + in_row2.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + in_row2.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + in_row2.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row2.s, y_cond, z_cond0.s2); // Row3 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 3; - valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - in_row3.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z); - in_row3.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z); - in_row3.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z); - in_row3.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z); - in_row3.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z); - in_row3.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z); - in_row3.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z); - in_row3.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z); + in_row3.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + in_row3.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + in_row3.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + in_row3.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + in_row3.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + in_row3.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + in_row3.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + in_row3.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row3.s, y_cond, z_cond0.s3); // Row4 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 4; - valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - in_row4.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z); - in_row4.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z); - in_row4.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z); - in_row4.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z); - in_row4.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z); - in_row4.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z); - in_row4.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z); - in_row4.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z); + in_row4.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + in_row4.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + in_row4.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + in_row4.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + in_row4.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + in_row4.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + in_row4.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + in_row4.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row4.s, y_cond, z_cond0.s4); // Row5 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 5; - valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - in_row5.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z); - in_row5.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z); - in_row5.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z); - in_row5.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z); - in_row5.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z); - in_row5.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z); - in_row5.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z); - in_row5.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z); + in_row5.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + in_row5.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + in_row5.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + in_row5.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + in_row5.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + in_row5.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + in_row5.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + in_row5.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row5.s, y_cond, z_cond0.s5); // Row6 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 6; - valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - in_row6.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z); - in_row6.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z); - in_row6.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z); - in_row6.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z); - in_row6.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z); - in_row6.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z); - in_row6.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z); - in_row6.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z); + in_row6.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + in_row6.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + in_row6.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + in_row6.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + in_row6.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + in_row6.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + in_row6.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + in_row6.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row6.s, y_cond, z_cond0.s6); // Row7 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 7; - valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - in_row7.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z); - in_row7.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z); - in_row7.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z); - in_row7.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z); - in_row7.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z); - in_row7.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z); - in_row7.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z); - in_row7.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z); + in_row7.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + in_row7.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + in_row7.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + in_row7.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + in_row7.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + in_row7.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + in_row7.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + in_row7.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row7.s, y_cond, z_cond0.s7); VEC_DATA_TYPE(DATA_TYPE, 8) comm_fact0 = in_row2 + in_row6 - (DATA_TYPE)4.25f * in_row4; @@ -1722,29 +1696,33 @@ __kernel void winograd_input_transform_2x2_7x7_stepz1_nhwc( __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE); #endif /* defined(NUM_TILES_Y) */ -#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) + // Origin coordinates for the width (y) and height (z) in the input tensor + int8 y_coord0 = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT; + int8 z_coord0 = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP; + + // Coordinates to use to avoid out-of-bound reads + int8 y_coord_valid0 = clamp(y_coord0, (int8)0, (int8)((int)SRC_DIM_1 - 1)); + int8 z_coord_valid0 = clamp(z_coord0, (int8)0, (int8)((int)SRC_DIM_2 - 1)); - // Clamp coordinates. This clamp is valid for all rows - int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT; - y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1); + // Boundary conditions + int8 y_cond0 = y_coord_valid0 == y_coord0; + int8 z_cond0 = z_coord_valid0 == z_coord0; - // Clamp coordinates. This clamp is valid for all columns - int z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0; - int8 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); // If z < 0, set y to -1 - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2 - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); +#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) // Load the input tile VEC_DATA_TYPE(DATA_TYPE, 8) in_row0; - in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z); + in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0); VEC_DATA_TYPE(DATA_TYPE, 8) out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f; @@ -1758,27 +1736,19 @@ __kernel void winograd_input_transform_2x2_7x7_stepz1_nhwc( OUTPUT_ROW_2x2_7x7(out0, tmp0, comm_fact0); #elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) - // We can skip the border clamping along the y dimension as we cannot read out-of-bound in case of 1x5 kernels - int y_coord = y * (int)OUTPUT_TILE_W; - - // Row0 - // We can skip the border clamping along the z dimension as we cannot read out-of-bound in case of 5x1 kernels - int8 z_coord = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP; - int8 valid_y = select((int8)y_coord, (int8) - 1, z_coord < (int8)0); // If z < 0, set y to -1 - valid_y = select(valid_y, (int8)SRC_DIM_1, z_coord >= (int8)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2 - z_coord = clamp(z_coord, (int8)0, (int8)SRC_DIM_2 - 1); // Clamp z coordinate - // Load the input tile VEC_DATA_TYPE(DATA_TYPE, 8) in_row0; - in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord.s0 * (int)src_stride_z); - in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord.s1 * (int)src_stride_z); - in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord.s2 * (int)src_stride_z); - in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord.s3 * (int)src_stride_z); - in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord.s4 * (int)src_stride_z); - in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord.s5 * (int)src_stride_z); - in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord.s6 * (int)src_stride_z); - in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord.s7 * (int)src_stride_z); + in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_V(DATA_TYPE, in_row0.s, y_cond0.s0, z_cond); // Calculate common factors for intermediate tensor VEC_DATA_TYPE(DATA_TYPE, 8) @@ -1795,130 +1765,101 @@ __kernel void winograd_input_transform_2x2_7x7_stepz1_nhwc( VEC_DATA_TYPE(DATA_TYPE, 8) in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7; - // Clamp coordinates. This clamp is valid for all rows - int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT; - y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1); - // Row0 - int z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0; - int8 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); // If z < 0, set y to -1 - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2 - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); // Clamp z coordinate + in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); - // Load the input tile - in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z); + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0); // Row1 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 1; - valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - in_row1.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row1.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row1.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row1.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row1.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row1.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row1.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row1.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z); + in_row1.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + in_row1.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + in_row1.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + in_row1.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + in_row1.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + in_row1.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + in_row1.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + in_row1.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row1.s, y_cond, z_cond0.s1); // Row2 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 2; - valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - in_row2.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row2.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row2.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row2.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row2.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row2.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row2.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row2.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z); + in_row2.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + in_row2.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + in_row2.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + in_row2.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + in_row2.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + in_row2.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + in_row2.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + in_row2.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row2.s, y_cond, z_cond0.s2); // Row3 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 3; - valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - in_row3.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row3.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row3.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row3.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row3.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row3.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row3.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row3.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z); + in_row3.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + in_row3.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + in_row3.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + in_row3.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + in_row3.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + in_row3.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + in_row3.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + in_row3.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row3.s, y_cond, z_cond0.s3); // Row4 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 4; - valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - in_row4.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row4.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row4.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row4.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row4.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row4.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row4.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row4.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z); + in_row4.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + in_row4.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + in_row4.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + in_row4.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + in_row4.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + in_row4.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + in_row4.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + in_row4.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row4.s, y_cond, z_cond0.s4); // Row5 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 5; - valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - in_row5.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row5.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row5.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row5.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row5.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row5.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row5.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row5.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z); + in_row5.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + in_row5.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + in_row5.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + in_row5.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + in_row5.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + in_row5.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + in_row5.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + in_row5.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row5.s, y_cond, z_cond0.s5); // Row6 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 6; - valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - in_row6.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row6.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row6.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row6.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row6.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row6.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row6.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row6.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z); + in_row6.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + in_row6.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + in_row6.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + in_row6.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + in_row6.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + in_row6.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + in_row6.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + in_row6.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row6.s, y_cond, z_cond0.s6); // Row7 - z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 7; - valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); - valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); - z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); - - in_row7.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row7.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row7.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row7.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row7.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row7.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row7.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z); - in_row7.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z); + in_row7.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + in_row7.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + in_row7.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + in_row7.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + in_row7.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + in_row7.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + in_row7.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + in_row7.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row7.s, y_cond, z_cond0.s7); VEC_DATA_TYPE(DATA_TYPE, 8) comm_fact0 = (DATA_TYPE)36.0f * in_row2 - (DATA_TYPE)13.0f * in_row4 + in_row6; diff --git a/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp b/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp index 6b1b86a777..c4c2b08a81 100644 --- a/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp +++ b/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp @@ -87,11 +87,6 @@ std::pair validate_and_configure_window(ITensorInfo *input, ITen AccessWindowRectangle input_access(input, -conv_info.pad_left(), -conv_info.pad_top(), num_elems_read_per_iteration_x, num_elems_read_per_iteration_y); window_changed = update_window_and_padding(win, input_access); } - else - { - AccessWindowStatic input_access(input, 0, -1, input->dimension(0), input->dimension(1) + 1); - window_changed = update_window_and_padding(win, input_access); - } Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{}; return std::make_pair(err, win); @@ -141,7 +136,7 @@ void CLWinogradInputTransformKernel::configure(const CLCompileContext &compile_c } else { - _border_size = BorderSize(1U, 0U, 1U, 0); + _border_size = BorderSize(); } // Compute the number of output tiles along the x and y direction of size "output_tile_size" diff --git a/tests/validation/CL/Winograd.cpp b/tests/validation/CL/Winograd.cpp index 771acf9461..d1522f3e7f 100644 --- a/tests/validation/CL/Winograd.cpp +++ b/tests/validation/CL/Winograd.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 Arm Limited. + * Copyright (c) 2018-2020 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -182,6 +182,29 @@ const auto ActivationFunctionsSmallDataset = framework::dataset::make("Activatio ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LEAKY_RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::SOFT_RELU) }); + +/** Zero padding test */ +bool validate_zero_padding(unsigned int width, unsigned height) +{ + TensorShape shape(width, height, 11, 1); + + WinogradInfo winograd_info = WinogradInfo(Size2D(4U, 4U), Size2D(5U, 5U), Size2D(width, height), PadStrideInfo(), DataLayout::NHWC); + + // Create tensors + CLTensor src = create_tensor(shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC); + CLTensor dst; + + src.info()->set_quantization_info(QuantizationInfo(1.f / 256.f, 0)); + dst.info()->set_quantization_info(QuantizationInfo(1.f / 256.f, 0)); + + CLWinogradInputTransform input_transform; + + input_transform.configure(&src, &dst, winograd_info); + + // Padding can be added along rhs and bias's X dimension + return src.info()->padding().empty() && dst.info()->padding().empty(); +} + } // namespace using namespace arm_compute::misc::shape_calculator; @@ -190,6 +213,20 @@ TEST_SUITE(CL) TEST_SUITE(Winograd) TEST_SUITE(InputTransform) + +/** Validate zero padding tests + * + * A series of validation tests to check that no padding is added + */ +DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip( +framework::dataset::make("Width", { 32U, 37U, 12U, 1U }), +framework::dataset::make("Height", { 13U, 27U, 19U, 1U })), +width, height) +{ + bool status = validate_zero_padding(width, height); + ARM_COMPUTE_EXPECT(status, framework::LogLevel::ERRORS); +} + DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( framework::dataset::make("InputInfo",{ TensorInfo(TensorShape(53U, 21U, 5U, 3U), 1, DataType::F16), // F16 not supported -- cgit v1.2.1