aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/core/CL/cl_kernels/pooling_layer.cl14
1 files changed, 11 insertions, 3 deletions
diff --git a/src/core/CL/cl_kernels/pooling_layer.cl b/src/core/CL/cl_kernels/pooling_layer.cl
index e69c3c35e9..680e947149 100644
--- a/src/core/CL/cl_kernels/pooling_layer.cl
+++ b/src/core/CL/cl_kernels/pooling_layer.cl
@@ -786,6 +786,8 @@ __kernel void pooling_layer_MxN_nhwc(
}
#endif // defined(POOL_SIZE_X) && defined(POOL_SIZE_Y)
+#define SELECT_TYPE SELECT_DATA_TYPE(ACC_DATA_TYPE, VEC_SIZE)
+
/** Performs pooling layer of size equal to 2. This OpenCL kernel can perform the following pooling types:
* -# max, -DPOOL_MAX must be passed at compile time
* -# max extracting the max index, -DPOOL_MAX and -DEXTRACT_MAX_INDEX must be passed at compile time
@@ -899,10 +901,16 @@ __kernel void pooling_layer_2x2_nhwc(
#if !defined(POOL_MAX)
if(filter_size != 4)
{
+ SELECT_TYPE cond_w_s = (SELECT_TYPE)idx_in_w < (SELECT_TYPE)0;
+ SELECT_TYPE cond_w_e = (SELECT_TYPE)idx_in_w >= (SELECT_TYPE)(SRC_WIDTH - 1);
+ SELECT_TYPE cond_h_s = (SELECT_TYPE)idx_in_h < (SELECT_TYPE)0;
+ SELECT_TYPE cond_h_e = (SELECT_TYPE)idx_in_h >= (SELECT_TYPE)(SRC_HEIGHT - 1);
+
// Make invalid the values loaded if the x or y coordinate was clamped (out-of-bound)
- data1 = select(data1, (VEC_DATA_TYPE(ACC_DATA_TYPE, VEC_SIZE))INITIAL_VALUE, (SELECT_DATA_TYPE(ACC_DATA_TYPE, VEC_SIZE))(pool_x_e == pool_x_s));
- data2 = select(data2, (VEC_DATA_TYPE(ACC_DATA_TYPE, VEC_SIZE))INITIAL_VALUE, (SELECT_DATA_TYPE(ACC_DATA_TYPE, VEC_SIZE))(pool_y_e == pool_y_s));
- data3 = select(data3, (VEC_DATA_TYPE(ACC_DATA_TYPE, VEC_SIZE))INITIAL_VALUE, (SELECT_DATA_TYPE(ACC_DATA_TYPE, VEC_SIZE))((pool_x_e == pool_x_s) || (pool_y_e == pool_y_s)));
+ data0 = select(data0, (VEC_DATA_TYPE(ACC_DATA_TYPE, VEC_SIZE))INITIAL_VALUE, (SELECT_TYPE)(cond_w_s | cond_h_s));
+ data1 = select(data1, (VEC_DATA_TYPE(ACC_DATA_TYPE, VEC_SIZE))INITIAL_VALUE, (SELECT_TYPE)(cond_w_e | cond_h_s));
+ data2 = select(data2, (VEC_DATA_TYPE(ACC_DATA_TYPE, VEC_SIZE))INITIAL_VALUE, (SELECT_TYPE)(cond_w_s | cond_h_e));
+ data3 = select(data3, (VEC_DATA_TYPE(ACC_DATA_TYPE, VEC_SIZE))INITIAL_VALUE, (SELECT_TYPE)(cond_w_e | cond_h_e));
}
#endif // !defined(POOL_MAX)