aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/pooling_layer.cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/pooling_layer.cl')
-rw-r--r--src/core/CL/cl_kernels/pooling_layer.cl30
1 files changed, 14 insertions, 16 deletions
diff --git a/src/core/CL/cl_kernels/pooling_layer.cl b/src/core/CL/cl_kernels/pooling_layer.cl
index b30145b11e..ebf7c5c078 100644
--- a/src/core/CL/cl_kernels/pooling_layer.cl
+++ b/src/core/CL/cl_kernels/pooling_layer.cl
@@ -724,10 +724,22 @@ __kernel void pooling_layer_MxN_nhwc(
VEC_DATA_TYPE(ACC_DATA_TYPE, VEC_SIZE)
res0 = INITIAL_VALUE;
-#if POOL_SIZE_X == SRC_WIDTH && POOL_SIZE_Y == SRC_HEIGHT
- // Global pooling path
+ int idx_in_w = idx_out_w * STRIDE_X - PAD_X;
+ int idx_in_h = idx_out_h * STRIDE_Y - PAD_Y;
+ int pool_x_s = max((int)0, -idx_in_w);
+ int pool_x_e = min((int)POOL_SIZE_X, (int)SRC_WIDTH - idx_in_w);
+ int pool_y_s = max((int)0, -idx_in_h);
+ int pool_y_e = min((int)POOL_SIZE_Y, (int)SRC_HEIGHT - idx_in_h);
+
+#if defined(EXCLUDE_PADDING)
+ int filter_size = (pool_y_e - pool_y_s) * (pool_x_e - pool_x_s);
+#else // defined(EXCLUDE_PADDING)
int filter_size = POOL_SIZE_X * POOL_SIZE_Y;
+#endif // defined(EXCLUDE_PADDING)
+
+#if POOL_SIZE_X == SRC_WIDTH && POOL_SIZE_Y == SRC_HEIGHT
+ // Global pooling path
#pragma unroll 8
for(int y = 0; y < POOL_SIZE_X * POOL_SIZE_Y; ++y)
@@ -752,20 +764,6 @@ __kernel void pooling_layer_MxN_nhwc(
}
#else // POOL_SIZE_X == SRC_WIDTH && POOL_SIZE_Y == SRC_HEIGHT
- int idx_in_w = idx_out_w * STRIDE_X - PAD_X;
- int idx_in_h = idx_out_h * STRIDE_Y - PAD_Y;
-
- int pool_x_s = max((int)0, -idx_in_w);
- int pool_x_e = min((int)POOL_SIZE_X, (int)SRC_WIDTH - idx_in_w);
- int pool_y_s = max((int)0, -idx_in_h);
- int pool_y_e = min((int)POOL_SIZE_Y, (int)SRC_HEIGHT - idx_in_h);
-
-#if defined(EXCLUDE_PADDING)
- int filter_size = (pool_y_e - pool_y_s) * (pool_x_e - pool_x_s);
-#else // defined(EXCLUDE_PADDING)
- int filter_size = POOL_SIZE_X * POOL_SIZE_Y;
-#endif // defined(EXCLUDE_PADDING)
-
for(int y = pool_y_s; y < pool_y_e; ++y)
{
for(int x = pool_x_s; x < pool_x_e; ++x)