From e55b40a4d0cc5a82b8f0fd9ffec203ded9f3c63d Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Thu, 13 Sep 2018 17:20:04 +0100 Subject: COMPMID-1581: Collapse windows Change-Id: Iec56c9a96d9736a63f13b65efa33311950f20661 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/148572 Reviewed-by: Anthony Barbier Tested-by: bsgcomp --- arm_compute/core/utils/misc/ShapeCalculator.h | 20 +++-- src/core/CL/cl_kernels/col2im.cl | 45 ++++++----- src/core/CL/cl_kernels/depthwise_convolution.cl | 92 ++++++++++++++-------- .../cl_kernels/depthwise_convolution_quantized.cl | 44 +++++++---- src/core/CL/kernels/CLCol2ImKernel.cpp | 18 +++-- .../CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp | 38 ++++----- src/core/CL/kernels/CLFillBorderKernel.cpp | 5 +- tests/datasets/Col2ImLayerDataset.h | 4 +- tests/validation/reference/Col2Im.cpp | 2 +- 9 files changed, 162 insertions(+), 106 deletions(-) diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index e88fd8d75e..6d8e15b8b2 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -176,13 +176,21 @@ inline TensorShape compute_col2im_shape(const ITensorInfo &input, const Size2D & ARM_COMPUTE_ERROR_ON(input.tensor_shape()[1] != (convolved_dims.area())); ARM_COMPUTE_ERROR_ON((num_groups > 1) && input.tensor_shape()[2] != num_groups); - TensorShape col2im_shape{ input.tensor_shape() }; - col2im_shape.set(0, convolved_dims.width); - col2im_shape.set(1, convolved_dims.height); - col2im_shape.set(2, input.tensor_shape()[0] * num_groups); + const DataLayout data_layout = input.data_layout(); + const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); - const unsigned int batch_idx = (batch_size_on_z && num_groups == 1) ? 2 : 3; - col2im_shape.set(3, input.tensor_shape()[batch_idx]); + TensorShape col2im_shape{ input.tensor_shape() }; + // If batches start on 3rd dimension shift dimensions right by 1 to retain upper tensor shape, + // as first three will be override by H,W,C data + if(batch_size_on_z && num_groups == 1) + { + col2im_shape.shift_right(1); + } + col2im_shape.set(width_idx, convolved_dims.width); + col2im_shape.set(height_idx, convolved_dims.height); + col2im_shape.set(channel_idx, input.tensor_shape()[0] * num_groups); return col2im_shape; } diff --git a/src/core/CL/cl_kernels/col2im.cl b/src/core/CL/cl_kernels/col2im.cl index 5e52127f27..b02d07b332 100644 --- a/src/core/CL/cl_kernels/col2im.cl +++ b/src/core/CL/cl_kernels/col2im.cl @@ -23,7 +23,7 @@ */ #include "helpers.h" -#if defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT) +#if defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT) && defined(NUM_GROUPS) #if ELEMENT_SIZE == 1 #define COND_DATA_TYPE char @@ -41,7 +41,7 @@ * @note The width of the input tensor must be passed at compile time using -DWIDTH_INPUT: e.g. -DWIDTH_INPUT=320 * @note The width of the output tensor must be passed at compile time using -DWIDTH_OUTPUT: e.g. -DWIDTH_OUTPUT=600 * @note The element size must be passed at compile time using -DELEMENT_SIZE: e.g. -DELEMENT_SIZE=4 - * @note In case of grouping the GROUPING flag must be passed at compile time using -DGROUPING + * @note The number of groups must be passed at compile time using -DNUM_GROUPS: e.g. -DNUM_GROUPS=4 * * @param[in] src_ptr Pointer to the source tensor. Supported data types: QASYMM8/F16/F32 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes) @@ -58,15 +58,16 @@ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes) - * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes) + * @param[in] dst_step_w dst_stride_w * number of elements along W processed per workitem(in bytes) + * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor */ __kernel void col2im( TENSOR3D_DECLARATION(src), - TENSOR3D_DECLARATION(dst), - uint dst_stride_w) + TENSOR4D_DECLARATION(dst)) { Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src); + Tensor4D dst = CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(dst, 0); const uint xd = get_global_id(1) % WIDTH_OUTPUT; // x coordinate of the destination tensor const uint yd = get_global_id(1) / WIDTH_OUTPUT; // y coordinate of the destination tensor @@ -86,27 +87,25 @@ __kernel void col2im( // If out-of-bound, overwrite with the first element data = select((VEC_DATA_TYPE(DATA_TYPE, 8))data.s0, data, cond0); - __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes; - -#if defined(GROUPING) - // Compute output offset (batches on 4th dimension, no need to compute manually) - int idx = yd * dst_stride_y + xd * dst_stride_x; +#if NUM_GROUPS > 1 + // Compute output offset (batches on 4th dimension) + int idx = yd * dst_stride_y + xd * dst_stride_x + (get_global_id(2) / NUM_GROUPS) * dst.stride_w; - const uint group = get_global_id(2); // group ID + const uint group = get_global_id(2) % NUM_GROUPS; // group ID x_clamped += group * WIDTH_INPUT; -#else /* defined(GROUPING) */ +#else /* defined(NUM_GROUPS > 1 ) */ // Compute output offset (batches on 3rd dimension) - int idx = yd * dst_stride_y + xd * dst_stride_x + get_global_id(2) * dst_stride_w; -#endif /* GROUPING */ + int idx = yd * dst.stride_y + xd * dst.stride_x + get_global_id(2) * dst.stride_w; +#endif /* NUM_GROUPS > 1 */ // Store value - *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s0 * dst_stride_z)) = data.s0; - *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s1 * dst_stride_z)) = data.s1; - *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s2 * dst_stride_z)) = data.s2; - *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s3 * dst_stride_z)) = data.s3; - *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s4 * dst_stride_z)) = data.s4; - *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s5 * dst_stride_z)) = data.s5; - *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s6 * dst_stride_z)) = data.s6; - *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s7 * dst_stride_z)) = data.s7; + *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s0 * dst.stride_z)) = data.s0; + *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s1 * dst.stride_z)) = data.s1; + *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s2 * dst.stride_z)) = data.s2; + *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s3 * dst.stride_z)) = data.s3; + *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s4 * dst.stride_z)) = data.s4; + *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s5 * dst.stride_z)) = data.s5; + *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s6 * dst.stride_z)) = data.s6; + *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s7 * dst.stride_z)) = data.s7; } -#endif // defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT) +#endif // defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT) && defined(NUM_GROUPS) diff --git a/src/core/CL/cl_kernels/depthwise_convolution.cl b/src/core/CL/cl_kernels/depthwise_convolution.cl index 23237da562..97b46c47cf 100644 --- a/src/core/CL/cl_kernels/depthwise_convolution.cl +++ b/src/core/CL/cl_kernels/depthwise_convolution.cl @@ -24,7 +24,7 @@ #include "helpers.h" -#if defined(DEPTH_MULTIPLIER) +#if defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS) #if defined(CONV_STRIDE_X) #if CONV_STRIDE_X == 1 @@ -188,23 +188,28 @@ __kernel void depthwise_convolution_3x3( { Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src); Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst); - Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights); + Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights); #if defined(HAS_BIAS) Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases); #endif //defined(HAS_BIAS) - src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z; + // Extract channel and linearized batch indices + const int channel = get_global_id(2) % DST_CHANNELS; + const int batch = get_global_id(2) / DST_CHANNELS; + // Load relevant input and weights data (Accounts depth multiplier when indexing input, OFM = IFM * DEPTH_MULTIPLIER) + src.ptr -= batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z + (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; + __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z; uchar3 offset = (uchar3)(0, 1, 2) * (uchar3)weights_stride_y; - float3 weights_values0 = vload3(0, (__global float *)(weights.ptr + offset.s0)); - float3 weights_values1 = vload3(0, (__global float *)(weights.ptr + offset.s1)); - float3 weights_values2 = vload3(0, (__global float *)(weights.ptr + offset.s2)); + float3 weights_values0 = vload3(0, (__global float *)(weights_addr + offset.s0)); + float3 weights_values1 = vload3(0, (__global float *)(weights_addr + offset.s1)); + float3 weights_values2 = vload3(0, (__global float *)(weights_addr + offset.s2)); float2 pixels = convolution3x3(&src, weights_values0.s0, weights_values0.s1, weights_values0.s2, weights_values1.s0, weights_values1.s1, weights_values1.s2, weights_values2.s0, weights_values2.s1, weights_values2.s2); #if defined(HAS_BIAS) - pixels += (float2)(*((__global float *)(biases.ptr + get_global_id(2) * biases_stride_x))); + pixels += (float2)(*((__global float *)(biases.ptr + channel * biases_stride_x))); #endif //defined(HAS_BIAS) vstore2(pixels, 0, (__global float *)dst.ptr); @@ -307,15 +312,19 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f32( { Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src); Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst); - Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights); + Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights); float2 pixels0 = 0.0f; float2 pixels1 = 0.0f; float2 pixels2 = 0.0f; float2 pixels3 = 0.0f; - __global uchar *weights_addr = (__global uchar *)weights.ptr; - __global uchar *src_addr = src.ptr - (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z; + // Extract channel and linearized batch indices + const int channel = get_global_id(2) % DST_CHANNELS; + const int batch = get_global_id(2) / DST_CHANNELS; + // Load relevant input and weights data (Accounts depth multiplier when indexing input, OFM = IFM * DEPTH_MULTIPLIER) + __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z; + __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; // Load the weights float3 weights_row0 = vload3(0, (__global float *)(weights_addr + 0 * weights_stride_y)); @@ -346,7 +355,7 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f32( #ifdef HAS_BIAS Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases); - float bias = *((__global float *)(vector_offset(&biases, get_global_id(2)))); + float bias = *((__global float *)(vector_offset(&biases, channel))); pixels0 += (float2)bias; pixels1 += (float2)bias; @@ -404,13 +413,17 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f32( { Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src); Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst); - Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights); + Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights); float2 pixels0 = 0.0f; float2 pixels1 = 0.0f; - __global uchar *weights_addr = (__global uchar *)weights.ptr; - __global uchar *src_addr = src.ptr - (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z; + // Extract channel and linearized batch indices + const int channel = get_global_id(2) % DST_CHANNELS; + const int batch = get_global_id(2) / DST_CHANNELS; + // Load relevant input and weights data (Accounts depth multiplier when indexing input, OFM = IFM * DEPTH_MULTIPLIER) + __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z; + __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; // Load the weights float3 weights_row0 = vload3(0, (__global float *)(weights_addr + 0 * weights_stride_y)); @@ -439,7 +452,7 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f32( #ifdef HAS_BIAS Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases); - float bias = *((__global float *)(vector_offset(&biases, get_global_id(2)))); + float bias = *((__global float *)(vector_offset(&biases, channel))); pixels0 += (float2)bias; pixels1 += (float2)bias; @@ -449,7 +462,7 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f32( vstore2(pixels1, 0, (__global float *)(dst.ptr + 1 * dst_stride_y)); } -#endif // defined(DEPTH_MULTIPLIER) +#endif // defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS) #if defined(NCHW) #define in_stride_x src_stride_x @@ -617,7 +630,7 @@ __kernel void depthwise_vector_to_tensor( #endif //defined(CONV_WIDTH) && defined(CONV_HEIGHT) && defined(DATA_TYPE) -#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(DEPTH_MULTIPLIER) +#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS) #if defined(CONV_STRIDE_X) #if CONV_STRIDE_X == 1 #define convolution1x3_f16 convolution1x3_stride_1_f16 @@ -781,23 +794,28 @@ __kernel void depthwise_convolution_3x3_f16( { Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src); Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst); - Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights); + Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights); #if defined(HAS_BIAS) Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases); #endif //defined(HAS_BIAS) - src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z; + // Extract channel and linearized batch indices + const int channel = get_global_id(2) % DST_CHANNELS; + const int batch = get_global_id(2) / DST_CHANNELS; + // Load relevant input and weights data (Accounts depth multiplier when indexing input, OFM = IFM * DEPTH_MULTIPLIER) + src.ptr -= batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z + (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; + __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z; uchar3 offset = (uchar3)(0, 1, 2) * (uchar3)weights_stride_y; - half3 weights_values0 = vload3(0, (__global half *)(weights.ptr + offset.s0)); - half3 weights_values1 = vload3(0, (__global half *)(weights.ptr + offset.s1)); - half3 weights_values2 = vload3(0, (__global half *)(weights.ptr + offset.s2)); + half3 weights_values0 = vload3(0, (__global half *)(weights_addr + offset.s0)); + half3 weights_values1 = vload3(0, (__global half *)(weights_addr + offset.s1)); + half3 weights_values2 = vload3(0, (__global half *)(weights_addr + offset.s2)); half4 pixels = convolution3x3_f16(&src, weights_values0.s0, weights_values0.s1, weights_values0.s2, weights_values1.s0, weights_values1.s1, weights_values1.s2, weights_values2.s0, weights_values2.s1, weights_values2.s2); #if defined(HAS_BIAS) - pixels += (half4)(*((__global half *)(biases.ptr + get_global_id(2) * biases_stride_x))); + pixels += (half4)(*((__global half *)(biases.ptr + channel * biases_stride_x))); #endif //defined(HAS_BIAS) vstore4(pixels, 0, (__global half *)dst.ptr); @@ -849,12 +867,16 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f16( { Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src); Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst); - Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights); + Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights); + + // Extract channel and linearized batch indices + const int channel = get_global_id(2) % DST_CHANNELS; + const int batch = get_global_id(2) / DST_CHANNELS; #ifdef HAS_BIAS Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases); - half bias = *((__global half *)(vector_offset(&biases, get_global_id(2)))); + half bias = *((__global half *)(vector_offset(&biases, channel))); #endif /* defined(HAS_BIAS) */ half4 pixels0 = 0.0f; @@ -862,8 +884,9 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f16( half4 pixels2 = 0.0f; half4 pixels3 = 0.0f; - __global uchar *weights_addr = (__global uchar *)weights.ptr; - __global uchar *src_addr = (__global uchar *)offset(&src, 0, 0) - (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z; + // Load relevant input and weights data (Accounts depth multiplier when indexing input, OFM = IFM * DEPTH_MULTIPLIER) + __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z; + __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; // Load the weights half3 weights_row0 = vload3(0, (__global half *)(weights_addr + 0 * weights_stride_y)); @@ -948,19 +971,24 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f16( { Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src); Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst); - Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights); + Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights); + + // Extract channel and linearized batch indices + const int channel = get_global_id(2) % DST_CHANNELS; + const int batch = get_global_id(2) / DST_CHANNELS; #ifdef HAS_BIAS Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases); - half bias = *((__global half *)(vector_offset(&biases, get_global_id(2)))); + half bias = *((__global half *)(vector_offset(&biases, channel))); #endif /* defined(HAS_BIAS) */ half4 pixels0 = 0.0f; half4 pixels1 = 0.0f; - __global uchar *weights_addr = (__global uchar *)weights.ptr; - __global uchar *src_addr = (__global uchar *)offset(&src, 0, 0) - (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z; + // Load relevant input and weights data ( Accounts depth multiplier when indexing input, OFM = IFM * DEPTH_MULTIPLIER) + __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z; + __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; // Load the weights half3 weights_row0 = vload3(0, (__global half *)(weights_addr + 0 * weights_stride_y)); @@ -994,7 +1022,7 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f16( vstore4(pixels0, 0, (__global half *)(dst.ptr + 0 * dst_stride_y)); vstore4(pixels1, 0, (__global half *)(dst.ptr + 1 * dst_stride_y)); } -#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(DEPTH_MULTIPLIER) +#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS) #if defined(VEC_SIZE) && defined(SRC_DIM_2) && defined(CONV_PAD_TOP) && defined(CONV_PAD_LEFT) && defined(DATA_TYPE) diff --git a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl index 71889830c5..b3edc52612 100644 --- a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl +++ b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl @@ -45,7 +45,7 @@ #endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8) #endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) -#if defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER) +#if defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS) #if CONV_STRIDE_X > 3 #error "Stride X not supported" @@ -129,18 +129,25 @@ __kernel void depthwise_convolution_3x3_quantized_nchw( { Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src); Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst); - Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights); + Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights); + + // Extract channel and linearized batch indices + const int channel = get_global_id(2) % DST_CHANNELS; + const int batch = get_global_id(2) / DST_CHANNELS; + #if defined(HAS_BIAS) Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases); - int bias_value = *((__global int *)(vector_offset(&biases, get_global_id(2)))); + int bias_value = *((__global int *)(vector_offset(&biases, channel)); #endif //defined(HAS_BIAS) - src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z; + // Load relevant input and weights data (Accounts depth multiplier when indexing input, OFM = IFM * DEPTH_MULTIPLIER) + src.ptr -= batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z + (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; + __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z; - uchar3 w0 = vload3(0, weights.ptr + 0 * weights_stride_y); - uchar3 w1 = vload3(0, weights.ptr + 1 * weights_stride_y); - uchar3 w2 = vload3(0, weights.ptr + 2 * weights_stride_y); + uchar3 w0 = vload3(0, weights_addr + 0 * weights_stride_y); + uchar3 w1 = vload3(0, weights_addr + 1 * weights_stride_y); + uchar3 w2 = vload3(0, weights_addr + 2 * weights_stride_y); int8 values0 = 0; int8 sum0 = 0; @@ -337,18 +344,25 @@ __kernel void depthwise_convolution_3x3_quantized_dot8_nchw( { Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src); Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst); - Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights); + Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights); + + // Extract channel and linearized batch indices + const int channel = get_global_id(2) % DST_CHANNELS; + const int batch = get_global_id(2) / DST_CHANNELS; + #if defined(HAS_BIAS) - Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases); + Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases); - const int bias_value = *((__global int *)(vector_offset(&biases, get_global_id(2)))); + const int bias_value = *((__global int *)(vector_offset(&biases, channel))); #endif //defined(HAS_BIAS) - src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z; + // Load relevant input and weights data (Accounts depth multiplier when indexing input, OFM = IFM * DEPTH_MULTIPLIER) + src.ptr -= batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z + (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; + __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z; - uchar3 w0 = vload3(0, weights.ptr + 0 * weights_stride_y); - uchar3 w1 = vload3(0, weights.ptr + 1 * weights_stride_y); - uchar3 w2 = vload3(0, weights.ptr + 2 * weights_stride_y); + uchar3 w0 = vload3(0, weights_addr + 0 * weights_stride_y); + uchar3 w1 = vload3(0, weights_addr + 1 * weights_stride_y); + uchar3 w2 = vload3(0, weights_addr + 2 * weights_stride_y); uchar8 left0, middle0, right0; uchar8 left1, middle1, right1; @@ -501,7 +515,7 @@ __kernel void depthwise_convolution_3x3_quantized_dot8_nchw( #endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) -#endif /* defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER) */ +#endif /* defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS) */ #if defined(VEC_SIZE) && defined(SRC_DIM_1) && defined(SRC_DIM_2) && defined(CONV_PAD_TOP) && defined(CONV_PAD_LEFT) diff --git a/src/core/CL/kernels/CLCol2ImKernel.cpp b/src/core/CL/kernels/CLCol2ImKernel.cpp index 74bbb9b4df..d748745999 100644 --- a/src/core/CL/kernels/CLCol2ImKernel.cpp +++ b/src/core/CL/kernels/CLCol2ImKernel.cpp @@ -106,7 +106,7 @@ void CLCol2ImKernel::configure(const ICLTensor *input, ICLTensor *output, const build_opts.add_option("-DELEMENT_SIZE=" + support::cpp11::to_string(input->info()->element_size())); build_opts.add_option("-DWIDTH_INPUT=" + support::cpp11::to_string(input->info()->dimension(0))); build_opts.add_option("-DWIDTH_OUTPUT=" + support::cpp11::to_string(_convolved_dims.width)); - build_opts.add_option_if(num_groups > 1, "-DGROUPING"); + build_opts.add_option("-DNUM_GROUPS=" + support::cpp11::to_string(num_groups)); _kernel = static_cast(CLKernelLibrary::get().create_kernel("col2im", build_opts.options())); @@ -143,22 +143,26 @@ void CLCol2ImKernel::run(const Window &window, cl::CommandQueue &queue) ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(ICLKernel::window(), window); + bool is_collapsed = false; + bool is_collapsed_out = false; + Window out_window; out_window.use_tensor_dimensions(_output->info()->tensor_shape()); - Window slice = window.first_slice_window_3D(); - Window slice_out = out_window.first_slice_window_3D(); + Window collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ, &is_collapsed); + Window collapsed_out = out_window.collapse_if_possible(out_window, 3, &is_collapsed_out); - unsigned int idx = 2 * num_arguments_per_3D_tensor(); - _kernel.setArg(idx++, _output->info()->strides_in_bytes()[3]); + ARM_COMPUTE_ERROR_ON(is_collapsed != is_collapsed_out); + Window slice = collapsed.first_slice_window_3D(); + Window slice_out = collapsed_out.first_slice_window_4D(); do { // Set inputs unsigned int idx = 0; add_3D_tensor_argument(idx, _input, slice); - add_3D_tensor_argument(idx, _output, slice_out); + add_4D_tensor_argument(idx, _output, slice_out); enqueue(queue, *this, slice, lws_hint()); } - while(window.slide_window_slice_3D(slice) && out_window.slide_window_slice_3D(slice_out)); + while(collapsed.slide_window_slice_3D(slice) && collapsed_out.slide_window_slice_4D(slice_out)); } diff --git a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp index a40aa2856c..de7e2b8737 100644 --- a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp +++ b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp @@ -225,8 +225,17 @@ void CLDepthwiseConvolutionLayer3x3NCHWKernel::configure(const ICLTensor *input, _conv_pad_top = conv_info.pad_top(); _border_size = BorderSize(_conv_pad_top, conv_info.pad_right(), conv_info.pad_bottom(), _conv_pad_left); + // Configure kernel window + std::string kernel_name; + const GPUTarget gpu_target = get_target(); + + auto win_config = validate_and_configure_window(input->info(), weights->info(), output->info(), conv_info, depth_multiplier, gpu_target, kernel_name); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + ICLKernel::configure_internal(win_config.second); + // Set build options CLBuildOptions build_opts; + build_opts.add_option("-DDST_CHANNELS=" + support::cpp11::to_string(_output->info()->tensor_shape().z())); build_opts.add_option("-DDEPTH_MULTIPLIER=" + support::cpp11::to_string(depth_multiplier)); build_opts.add_option("-DCONV_STRIDE_X=" + support::cpp11::to_string(_conv_stride_x)); build_opts.add_option_if(_biases != nullptr, "-DHAS_BIAS"); @@ -273,15 +282,6 @@ void CLDepthwiseConvolutionLayer3x3NCHWKernel::configure(const ICLTensor *input, } } } - - // Configure kernel window - std::string kernel_name; - const GPUTarget gpu_target = get_target(); - - auto win_config = validate_and_configure_window(input->info(), weights->info(), output->info(), conv_info, depth_multiplier, gpu_target, kernel_name); - ARM_COMPUTE_ERROR_THROW_ON(win_config.first); - ICLKernel::configure_internal(win_config.second); - _kernel = static_cast(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options())); // Set config_id for enabling LWS tuning @@ -316,15 +316,17 @@ void CLDepthwiseConvolutionLayer3x3NCHWKernel::run(const Window &window, cl::Com ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window); + Window collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ); + // Create input window and adjust - Window win_in = window; - win_in.adjust(Window::DimX, -_conv_pad_left, true); - win_in.adjust(Window::DimY, -_conv_pad_top, true); - win_in.set_dimension_step(Window::DimX, window.x().step() * _conv_stride_x); - win_in.set_dimension_step(Window::DimY, window.y().step() * _conv_stride_y); - - Window slice_in = win_in.first_slice_window_3D(); - Window slice_out = window.first_slice_window_3D(); + Window collapsed_in = collapsed; + collapsed_in.adjust(Window::DimX, -_conv_pad_left, true); + collapsed_in.adjust(Window::DimY, -_conv_pad_top, true); + collapsed_in.set_dimension_step(Window::DimX, collapsed_in.x().step() * _conv_stride_x); + collapsed_in.set_dimension_step(Window::DimY, collapsed_in.y().step() * _conv_stride_y); + + Window slice_in = collapsed_in.first_slice_window_3D(); + Window slice_out = collapsed.first_slice_window_3D(); Window slice_weights = window.first_slice_window_3D(); slice_weights.set_dimension_step(Window::DimX, 0); slice_weights.set_dimension_step(Window::DimY, 0); @@ -347,5 +349,5 @@ void CLDepthwiseConvolutionLayer3x3NCHWKernel::run(const Window &window, cl::Com enqueue(queue, *this, slice_out, lws_hint()); } - while(window.slide_window_slice_3D(slice_out) && win_in.slide_window_slice_3D(slice_in)); + while(collapsed.slide_window_slice_3D(slice_out) && collapsed_in.slide_window_slice_3D(slice_in)); } diff --git a/src/core/CL/kernels/CLFillBorderKernel.cpp b/src/core/CL/kernels/CLFillBorderKernel.cpp index baf6bb6024..69206678d0 100644 --- a/src/core/CL/kernels/CLFillBorderKernel.cpp +++ b/src/core/CL/kernels/CLFillBorderKernel.cpp @@ -168,7 +168,8 @@ void CLFillBorderKernel::run(const Window &window, cl::CommandQueue &queue) ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(ICLKernel::window(), window); - Window slice = window.first_slice_window_3D(); + Window collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ); + Window slice = collapsed.first_slice_window_3D(); do { @@ -176,5 +177,5 @@ void CLFillBorderKernel::run(const Window &window, cl::CommandQueue &queue) add_3D_tensor_argument(idx, _tensor, slice); enqueue(queue, *this, slice, cl::NullRange); } - while(window.slide_window_slice_3D(slice)); + while(collapsed.slide_window_slice_3D(slice)); } diff --git a/tests/datasets/Col2ImLayerDataset.h b/tests/datasets/Col2ImLayerDataset.h index 96a3cab134..b39cedbde6 100644 --- a/tests/datasets/Col2ImLayerDataset.h +++ b/tests/datasets/Col2ImLayerDataset.h @@ -128,7 +128,7 @@ public: add_config(TensorShape(8U, 16U, 3U, 1U), 4U, 4U, 3U); add_config(TensorShape(8U, 16U, 3U, 3U), 4U, 4U, 3U); add_config(TensorShape(12U, 20U, 4U, 1U), 5U, 4U, 4U); - add_config(TensorShape(12U, 20U, 4U, 3U), 5U, 4U, 4U); + add_config(TensorShape(12U, 20U, 4U, 3U, 2U), 5U, 4U, 4U); } }; @@ -142,7 +142,7 @@ public: add_config(TensorShape(333U, 280U, 1U, 77U), 14U, 20U, 1U); add_config(TensorShape(333U, 280U, 77U, 1U), 14U, 20U, 1U); add_config(TensorShape(120U, 300U, 8U, 3U), 20U, 15U, 8U); - add_config(TensorShape(233U, 300U, 8U, 3U), 20U, 15U, 8U); + add_config(TensorShape(233U, 300U, 8U, 3U, 2U), 20U, 15U, 8U); add_config(TensorShape(333U, 280U, 12U, 5U), 20U, 14U, 12U); add_config(TensorShape(177U, 300U, 12U, 5U), 15U, 20U, 12U); add_config(TensorShape(450U, 400U, 16U, 5U), 20U, 20U, 16U); diff --git a/tests/validation/reference/Col2Im.cpp b/tests/validation/reference/Col2Im.cpp index 90e488f928..53969d4725 100644 --- a/tests/validation/reference/Col2Im.cpp +++ b/tests/validation/reference/Col2Im.cpp @@ -40,7 +40,7 @@ SimpleTensor col2im(const SimpleTensor &src, const TensorShape &dst_shape, SimpleTensor dst{ dst_shape, src.data_type(), 1 }; // Compute reference - const size_t batches = dst_shape[3]; + const size_t batches = dst_shape.total_size() / (dst_shape.x() * dst_shape.y() * dst_shape.z()); const size_t src_width = src.shape().x(); const size_t src_height = src.shape().y(); -- cgit v1.2.1