aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/winograd_input_transform.cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/winograd_input_transform.cl')
-rw-r--r--src/core/CL/cl_kernels/winograd_input_transform.cl92
1 files changed, 46 insertions, 46 deletions
diff --git a/src/core/CL/cl_kernels/winograd_input_transform.cl b/src/core/CL/cl_kernels/winograd_input_transform.cl
index 6e969bd111..5e5b737785 100644
--- a/src/core/CL/cl_kernels/winograd_input_transform.cl
+++ b/src/core/CL/cl_kernels/winograd_input_transform.cl
@@ -23,48 +23,48 @@
*/
#include "helpers.h"
-#define FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(datatype, basename, y_cond, z_cond) \
- ({ \
- basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s0) && (z_cond))); \
- basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s1) && (z_cond))); \
- basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s2) && (z_cond))); \
- basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s3) && (z_cond))); \
- basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##1).s0) && (z_cond))); \
- basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##1).s1) && (z_cond))); \
+#define FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(datatype, basename, y_cond, z_cond) \
+ ({ \
+ basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s0) && (z_cond))); \
+ basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s1) && (z_cond))); \
+ basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s2) && (z_cond))); \
+ basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s3) && (z_cond))); \
+ basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype))(((y_cond##1).s0) && (z_cond))); \
+ basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype))(((y_cond##1).s1) && (z_cond))); \
})
-#define FILL_ZERO_OUT_OF_BOUND_6_NHWC_V(datatype, basename, y_cond, z_cond) \
- ({ \
- basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s0))); \
- basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s1))); \
- basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s2))); \
- basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s3))); \
- basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##1).s0))); \
- basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##1).s1))); \
+#define FILL_ZERO_OUT_OF_BOUND_6_NHWC_V(datatype, basename, y_cond, z_cond) \
+ ({ \
+ basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s0))); \
+ basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s1))); \
+ basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s2))); \
+ basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s3))); \
+ basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##1).s0))); \
+ basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##1).s1))); \
})
-#define FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(datatype, basename, y_cond, z_cond) \
- ({ \
- basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s0) && (z_cond))); \
- basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s1) && (z_cond))); \
- basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s2) && (z_cond))); \
- basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s3) && (z_cond))); \
- basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s4) && (z_cond))); \
- basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s5) && (z_cond))); \
- basename##6 = select((datatype)0, basename##6, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s6) && (z_cond))); \
- basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype, 1))(((y_cond##0).s7) && (z_cond))); \
+#define FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(datatype, basename, y_cond, z_cond) \
+ ({ \
+ basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s0) && (z_cond))); \
+ basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s1) && (z_cond))); \
+ basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s2) && (z_cond))); \
+ basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s3) && (z_cond))); \
+ basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s4) && (z_cond))); \
+ basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s5) && (z_cond))); \
+ basename##6 = select((datatype)0, basename##6, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s6) && (z_cond))); \
+ basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s7) && (z_cond))); \
})
-#define FILL_ZERO_OUT_OF_BOUND_8_NHWC_V(datatype, basename, y_cond, z_cond) \
- ({ \
- basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s0))); \
- basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s1))); \
- basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s2))); \
- basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s3))); \
- basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s4))); \
- basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s5))); \
- basename##6 = select((datatype)0, basename##6, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s6))); \
- basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype, 1))((y_cond) && ((z_cond##0).s7))); \
+#define FILL_ZERO_OUT_OF_BOUND_8_NHWC_V(datatype, basename, y_cond, z_cond) \
+ ({ \
+ basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s0))); \
+ basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s1))); \
+ basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s2))); \
+ basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s3))); \
+ basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s4))); \
+ basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s5))); \
+ basename##6 = select((datatype)0, basename##6, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s6))); \
+ basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s7))); \
})
#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact) \
@@ -1000,7 +1000,7 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
const int b = get_global_id(2) / NUM_TILES_Y;
#else // defined(NUM_TILES_Y)
// Index height
- const int z = get_global_id(2);
+ const int z = get_global_id(2);
#endif // defined(NUM_TILES_Y)
#if defined(NUM_TILES_Y)
@@ -1064,12 +1064,12 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d0, y_cond, z_cond0.s0);
#else // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
- DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
- DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
- DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
- DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
- DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
- DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
+ DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
+ DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
+ DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
+ DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
+ DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_6_NHWC_V(DATA_TYPE, d0, y_cond0.s0, z_cond);
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
@@ -1135,7 +1135,7 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
#if defined(NUM_TILES_Y)
__global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w);
#else // defined(NUM_TILES_Y)
- __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y);
+ __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y);
#endif // defined(NUM_TILES_Y)
uint dst_plane_stride = dst_stride_z / sizeof(DATA_TYPE);
@@ -1354,14 +1354,14 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
const int z = get_global_id(2) % NUM_TILES_Y;
const int b = get_global_id(2) / NUM_TILES_Y;
#else // defined(NUM_TILES_Y)
- const int z = get_global_id(2);
+ const int z = get_global_id(2);
#endif // defined(NUM_TILES_Y)
// Compute input address
#if defined(NUM_TILES_Y)
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w;
#else // defined(NUM_TILES_Y)
- __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
+ __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
#endif // defined(NUM_TILES_Y)
// Origin coordinates for the width (y) and height (z) in the input tensor