aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/im2col.cl
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-06-19 13:09:53 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:53:34 +0000
commit19ea419e7f14d02aeb208c2fbd5a4ac55f4cb101 (patch)
treefe04ed9d40ebb8b717f63490f672a28c5b27d01e /src/core/CL/cl_kernels/im2col.cl
parentbb71fe50930f5669a7a325e0fa95fee559856793 (diff)
downloadComputeLibrary-19ea419e7f14d02aeb208c2fbd5a4ac55f4cb101.tar.gz
COMPMID-809: Add NHWC data format on CLGEMMConvolutionLayer.
Change-Id: I50e4f5e7d47e21c300f754bee2c216863075b5cf Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/136191 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/im2col.cl')
-rw-r--r--src/core/CL/cl_kernels/im2col.cl9
1 files changed, 6 insertions, 3 deletions
diff --git a/src/core/CL/cl_kernels/im2col.cl b/src/core/CL/cl_kernels/im2col.cl
index c60c9a996c..6f25ad4b7a 100644
--- a/src/core/CL/cl_kernels/im2col.cl
+++ b/src/core/CL/cl_kernels/im2col.cl
@@ -136,6 +136,7 @@ __kernel void im2col1x1_stridex1_dchw(
* @note The pad_left, pad_right, pad_top and pad_bottom must be passed at compile time using -DPAD_LEFT, -DPAD_RIGHT, -DPAD_TOP and -DPAD_BOTTOM: e.g. -DPAD_LEFT=1, -DPAD_RIGHT=2, -DPAD_TOP=3 and -DPAD_BOTTOM=2
* @note The zero value to store in case we load values out-of-bounds must be passed at compile time using -DPAD_VALUE: e.g. -DPAD_VALUE=0.0
* @note The stride along the X and Y directions must be passed at compile time using -DSTRIDE_X and -DSTRIDE_Y: e.g. -DSTRIDE_X=1 and -DSTRIDE_Y=1
+ * @note The dilation_x and dilation_y must be passed at compile time using -DDILATION_X and -DDILATION_Y: e.g. -DDILATION_X=1, -DDILATION_Y=1
* @note In case biases will be added to the convolution -DHAS_BIAS has to be passed to append the final matrix with 1 in each row.
*
* @param[in] src_ptr Pointer to the source tensor. Supported data types: QS8/QASYMM8/QS16/F16/F32
@@ -182,16 +183,18 @@ __kernel void im2col_generic_nhwc(
for(int yk = 0; yk < KERNEL_HEIGHT; ++yk)
{
- const int y0 = yi + yk;
+ const int dilated_offset_y = yk * DILATION_Y;
+ const int y0 = yi + dilated_offset_y;
if(y0 >= 0 && y0 < SRC_HEIGHT)
{
int xk;
for(xk = 0; xk < KERNEL_WIDTH; xk++)
{
- const int x0 = xi + xk;
+ const int dilated_offset_x = xk * DILATION_X;
+ const int x0 = xi + dilated_offset_x;
if(x0 >= 0 && x0 < SRC_WIDTH)
{
- *((__global DATA_TYPE *)output_ptr) = PTR_TO_VALUE(input_ptr + xk * src_stride_y + yk * src_stride_z, DATA_TYPE);
+ *((__global DATA_TYPE *)output_ptr) = PTR_TO_VALUE(input_ptr + dilated_offset_x * src_stride_y + dilated_offset_y * src_stride_z, DATA_TYPE);
}
else
{