diff options
Diffstat (limited to 'src/core/CL/cl_kernels/winograd_input_transform.cl')
-rw-r--r-- | src/core/CL/cl_kernels/winograd_input_transform.cl | 92 |
1 files changed, 46 insertions, 46 deletions
diff --git a/src/core/CL/cl_kernels/winograd_input_transform.cl b/src/core/CL/cl_kernels/winograd_input_transform.cl index 6e969bd111..5e5b737785 100644 --- a/src/core/CL/cl_kernels/winograd_input_transform.cl +++ b/src/core/CL/cl_kernels/winograd_input_transform.cl @@ -23,48 +23,48 @@ */ #include "helpers.h" -#define FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(datatype, basename, y_cond, z_cond) \ - ({ \ - basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s0) && (z_cond))); \ - basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s1) && (z_cond))); \ - basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s2) && (z_cond))); \ - basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s3) && (z_cond))); \ - basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##1).s0) && (z_cond))); \ - basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##1).s1) && (z_cond))); \ +#define FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(datatype, basename, y_cond, z_cond) \ + ({ \ + basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s0) && (z_cond))); \ + basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s1) && (z_cond))); \ + basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s2) && (z_cond))); \ + basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s3) && (z_cond))); \ + basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype))(((y_cond##1).s0) && (z_cond))); \ + basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype))(((y_cond##1).s1) && (z_cond))); \ }) -#define FILL_ZERO_OUT_OF_BOUND_6_NHWC_V(datatype, basename, y_cond, z_cond) \ - ({ \ - basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s0))); \ - basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s1))); \ - basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s2))); \ - basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s3))); \ - basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##1).s0))); \ - basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##1).s1))); \ +#define FILL_ZERO_OUT_OF_BOUND_6_NHWC_V(datatype, basename, y_cond, z_cond) \ + ({ \ + basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s0))); \ + basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s1))); \ + basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s2))); \ + basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s3))); \ + basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##1).s0))); \ + basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##1).s1))); \ }) -#define FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(datatype, basename, y_cond, z_cond) \ - ({ \ - basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s0) && (z_cond))); \ - basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s1) && (z_cond))); \ - basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s2) && (z_cond))); \ - basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s3) && (z_cond))); \ - basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s4) && (z_cond))); \ - basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s5) && (z_cond))); \ - basename##6 = select((datatype)0, basename##6, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s6) && (z_cond))); \ - basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s7) && (z_cond))); \ +#define FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(datatype, basename, y_cond, z_cond) \ + ({ \ + basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s0) && (z_cond))); \ + basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s1) && (z_cond))); \ + basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s2) && (z_cond))); \ + basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s3) && (z_cond))); \ + basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s4) && (z_cond))); \ + basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s5) && (z_cond))); \ + basename##6 = select((datatype)0, basename##6, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s6) && (z_cond))); \ + basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s7) && (z_cond))); \ }) -#define FILL_ZERO_OUT_OF_BOUND_8_NHWC_V(datatype, basename, y_cond, z_cond) \ - ({ \ - basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s0))); \ - basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s1))); \ - basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s2))); \ - basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s3))); \ - basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s4))); \ - basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s5))); \ - basename##6 = select((datatype)0, basename##6, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s6))); \ - basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s7))); \ +#define FILL_ZERO_OUT_OF_BOUND_8_NHWC_V(datatype, basename, y_cond, z_cond) \ + ({ \ + basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s0))); \ + basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s1))); \ + basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s2))); \ + basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s3))); \ + basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s4))); \ + basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s5))); \ + basename##6 = select((datatype)0, basename##6, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s6))); \ + basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s7))); \ }) #define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact) \ @@ -1000,7 +1000,7 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( const int b = get_global_id(2) / NUM_TILES_Y; #else // defined(NUM_TILES_Y) // Index height - const int z = get_global_id(2); + const int z = get_global_id(2); #endif // defined(NUM_TILES_Y) #if defined(NUM_TILES_Y) @@ -1064,12 +1064,12 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d0, y_cond, z_cond0.s0); #else // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) - DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); - DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); - DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); - DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); - DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z); - DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z); + DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); + DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z); + DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z); + DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z); + DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z); FILL_ZERO_OUT_OF_BOUND_6_NHWC_V(DATA_TYPE, d0, y_cond0.s0, z_cond); #endif // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) @@ -1135,7 +1135,7 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( #if defined(NUM_TILES_Y) __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w); #else // defined(NUM_TILES_Y) - __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y); + __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y); #endif // defined(NUM_TILES_Y) uint dst_plane_stride = dst_stride_z / sizeof(DATA_TYPE); @@ -1354,14 +1354,14 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc( const int z = get_global_id(2) % NUM_TILES_Y; const int b = get_global_id(2) / NUM_TILES_Y; #else // defined(NUM_TILES_Y) - const int z = get_global_id(2); + const int z = get_global_id(2); #endif // defined(NUM_TILES_Y) // Compute input address #if defined(NUM_TILES_Y) __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w; #else // defined(NUM_TILES_Y) - __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE); + __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE); #endif // defined(NUM_TILES_Y) // Origin coordinates for the width (y) and height (z) in the input tensor |