author    Georgios Pinitas <georgios.pinitas@arm.com>  2018-09-13 17:20:04 +0100
committer Anthony Barbier <anthony.barbier@arm.com>    2018-11-02 16:54:54 +0000
commit    e55b40a4d0cc5a82b8f0fd9ffec203ded9f3c63d (patch)
tree      e7736258428837e3889108909d58592937fe71fd
parent    64f1a908841913049ccc0eb941b5b213290d7bf7 (diff)
download  ComputeLibrary-e55b40a4d0cc5a82b8f0fd9ffec203ded9f3c63d.tar.gz
COMPMID-1581: Collapse windows
Change-Id: Iec56c9a96d9736a63f13b65efa33311950f20661
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/148572
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: bsgcomp <bsgcomp@arm.com>
-rw-r--r--  arm_compute/core/utils/misc/ShapeCalculator.h                    | 20
-rw-r--r--  src/core/CL/cl_kernels/col2im.cl                                 | 45
-rw-r--r--  src/core/CL/cl_kernels/depthwise_convolution.cl                  | 92
-rw-r--r--  src/core/CL/cl_kernels/depthwise_convolution_quantized.cl        | 44
-rw-r--r--  src/core/CL/kernels/CLCol2ImKernel.cpp                           | 18
-rw-r--r--  src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp | 38
-rw-r--r--  src/core/CL/kernels/CLFillBorderKernel.cpp                       |  5
-rw-r--r--  tests/datasets/Col2ImLayerDataset.h                              |  4
-rw-r--r--  tests/validation/reference/Col2Im.cpp                            |  2
9 files changed, 162 insertions(+), 106 deletions(-)
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index e88fd8d75e..6d8e15b8b2 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -176,13 +176,21 @@ inline TensorShape compute_col2im_shape(const ITensorInfo &input, const Size2D &
ARM_COMPUTE_ERROR_ON(input.tensor_shape()[1] != (convolved_dims.area()));
ARM_COMPUTE_ERROR_ON((num_groups > 1) && input.tensor_shape()[2] != num_groups);
- TensorShape col2im_shape{ input.tensor_shape() };
- col2im_shape.set(0, convolved_dims.width);
- col2im_shape.set(1, convolved_dims.height);
- col2im_shape.set(2, input.tensor_shape()[0] * num_groups);
+ const DataLayout data_layout = input.data_layout();
+ const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
- const unsigned int batch_idx = (batch_size_on_z && num_groups == 1) ? 2 : 3;
- col2im_shape.set(3, input.tensor_shape()[batch_idx]);
+ TensorShape col2im_shape{ input.tensor_shape() };
+ // If batches start on the 3rd dimension, shift dimensions right by 1 to retain the upper tensor shape,
+ // as the first three will be overridden by the H,W,C data
+ if(batch_size_on_z && num_groups == 1)
+ {
+ col2im_shape.shift_right(1);
+ }
+ col2im_shape.set(width_idx, convolved_dims.width);
+ col2im_shape.set(height_idx, convolved_dims.height);
+ col2im_shape.set(channel_idx, input.tensor_shape()[0] * num_groups);
return col2im_shape;
}
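
The new shape logic is easiest to see with concrete numbers. Below is a minimal stand-alone sketch of it (plain C++ with a std::vector stand-in for TensorShape, assuming NCHW index order width=0, height=1, channel=2), not the arm_compute API itself:

#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for TensorShape, for illustration only.
using Shape = std::vector<std::size_t>;

// Mirrors the compute_col2im_shape() logic above for NCHW.
Shape col2im_shape(const Shape &in, std::size_t conv_w, std::size_t conv_h,
                   std::size_t num_groups, bool batch_size_on_z)
{
    Shape out = in;
    if(batch_size_on_z && num_groups == 1)
    {
        out.insert(out.begin(), 1); // shift_right(1): keep the upper dims intact
    }
    out[0] = conv_w;                // W
    out[1] = conv_h;                // H
    out[2] = in[0] * num_groups;    // C
    return out;
}

int main()
{
    // im2col output [K, conv_area, batches] with batches on Z:
    const Shape s = col2im_shape({8, 16, 3}, 4, 4, /*num_groups*/ 1, /*batch_size_on_z*/ true);
    assert((s == Shape{4, 4, 8, 3})); // batches end up on the 4th dimension
    return 0;
}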
diff --git a/src/core/CL/cl_kernels/col2im.cl b/src/core/CL/cl_kernels/col2im.cl
index 5e52127f27..b02d07b332 100644
--- a/src/core/CL/cl_kernels/col2im.cl
+++ b/src/core/CL/cl_kernels/col2im.cl
@@ -23,7 +23,7 @@
*/
#include "helpers.h"
-#if defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT)
+#if defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT) && defined(NUM_GROUPS)
#if ELEMENT_SIZE == 1
#define COND_DATA_TYPE char
@@ -41,7 +41,7 @@
* @note The width of the input tensor must be passed at compile time using -DWIDTH_INPUT: e.g. -DWIDTH_INPUT=320
* @note The width of the output tensor must be passed at compile time using -DWIDTH_OUTPUT: e.g. -DWIDTH_OUTPUT=600
* @note The element size must be passed at compile time using -DELEMENT_SIZE: e.g. -DELEMENT_SIZE=4
- * @note In case of grouping the GROUPING flag must be passed at compile time using -DGROUPING
+ * @note The number of groups must be passed at compile time using -DNUM_GROUPS: e.g. -DNUM_GROUPS=4
*
* @param[in] src_ptr Pointer to the source tensor. Supported data types: QASYMM8/F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
@@ -58,15 +58,16 @@
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
+ * @param[in] dst_step_w dst_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
*/
__kernel void col2im(
TENSOR3D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst),
- uint dst_stride_w)
+ TENSOR4D_DECLARATION(dst))
{
Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+ Tensor4D dst = CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(dst, 0);
const uint xd = get_global_id(1) % WIDTH_OUTPUT; // x coordinate of the destination tensor
const uint yd = get_global_id(1) / WIDTH_OUTPUT; // y coordinate of the destination tensor
@@ -86,27 +87,25 @@ __kernel void col2im(
// If out-of-bound, overwrite with the first element
data = select((VEC_DATA_TYPE(DATA_TYPE, 8))data.s0, data, cond0);
- __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes;
-
-#if defined(GROUPING)
- // Compute output offset (batches on 4th dimension, no need to compute manually)
- int idx = yd * dst_stride_y + xd * dst_stride_x;
+#if NUM_GROUPS > 1
+ // Compute output offset (batches on 4th dimension)
+ int idx = yd * dst_stride_y + xd * dst_stride_x + (get_global_id(2) / NUM_GROUPS) * dst.stride_w;
- const uint group = get_global_id(2); // group ID
+ const uint group = get_global_id(2) % NUM_GROUPS; // group ID
x_clamped += group * WIDTH_INPUT;
-#else /* defined(GROUPING) */
+#else /* NUM_GROUPS > 1 */
// Compute output offset (batches on 3rd dimension)
- int idx = yd * dst_stride_y + xd * dst_stride_x + get_global_id(2) * dst_stride_w;
-#endif /* GROUPING */
+ int idx = yd * dst.stride_y + xd * dst.stride_x + get_global_id(2) * dst.stride_w;
+#endif /* NUM_GROUPS > 1 */
// Store value
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s0 * dst_stride_z)) = data.s0;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s1 * dst_stride_z)) = data.s1;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s2 * dst_stride_z)) = data.s2;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s3 * dst_stride_z)) = data.s3;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s4 * dst_stride_z)) = data.s4;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s5 * dst_stride_z)) = data.s5;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s6 * dst_stride_z)) = data.s6;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s7 * dst_stride_z)) = data.s7;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s0 * dst.stride_z)) = data.s0;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s1 * dst.stride_z)) = data.s1;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s2 * dst.stride_z)) = data.s2;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s3 * dst.stride_z)) = data.s3;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s4 * dst.stride_z)) = data.s4;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s5 * dst.stride_z)) = data.s5;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s6 * dst.stride_z)) = data.s6;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s7 * dst.stride_z)) = data.s7;
}
-#endif // defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT)
+#endif // defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT) && defined(NUM_GROUPS)
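
With NUM_GROUPS folded into the kernel, the collapsed Z dimension of the dispatch encodes both the group and the batch. A small stand-alone sketch of that decomposition (plain C++, assuming the window collapses Z as batch * NUM_GROUPS + group, consistent with the % and / pair above):

#include <cstdio>

int main()
{
    const int NUM_GROUPS = 4;
    const int BATCHES    = 2;

    for(int gid2 = 0; gid2 < BATCHES * NUM_GROUPS; ++gid2)
    {
        const int batch = gid2 / NUM_GROUPS; // selects the dst.stride_w offset
        const int group = gid2 % NUM_GROUPS; // shifts x_clamped by group * WIDTH_INPUT
        std::printf("gid2=%d -> batch=%d group=%d\n", gid2, batch, group);
    }
    return 0;
}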
diff --git a/src/core/CL/cl_kernels/depthwise_convolution.cl b/src/core/CL/cl_kernels/depthwise_convolution.cl
index 23237da562..97b46c47cf 100644
--- a/src/core/CL/cl_kernels/depthwise_convolution.cl
+++ b/src/core/CL/cl_kernels/depthwise_convolution.cl
@@ -24,7 +24,7 @@
#include "helpers.h"
-#if defined(DEPTH_MULTIPLIER)
+#if defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS)
#if defined(CONV_STRIDE_X)
#if CONV_STRIDE_X == 1
@@ -188,23 +188,28 @@ __kernel void depthwise_convolution_3x3(
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
- Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
#if defined(HAS_BIAS)
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
#endif //defined(HAS_BIAS)
- src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Extract channel and linearized batch indices
+ const int channel = get_global_id(2) % DST_CHANNELS;
+ const int batch = get_global_id(2) / DST_CHANNELS;
+ // Load relevant input and weights data (accounts for the depth multiplier when indexing the input, OFM = IFM * DEPTH_MULTIPLIER)
+ src.ptr -= batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z + (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
+ __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
uchar3 offset = (uchar3)(0, 1, 2) * (uchar3)weights_stride_y;
- float3 weights_values0 = vload3(0, (__global float *)(weights.ptr + offset.s0));
- float3 weights_values1 = vload3(0, (__global float *)(weights.ptr + offset.s1));
- float3 weights_values2 = vload3(0, (__global float *)(weights.ptr + offset.s2));
+ float3 weights_values0 = vload3(0, (__global float *)(weights_addr + offset.s0));
+ float3 weights_values1 = vload3(0, (__global float *)(weights_addr + offset.s1));
+ float3 weights_values2 = vload3(0, (__global float *)(weights_addr + offset.s2));
float2 pixels = convolution3x3(&src, weights_values0.s0, weights_values0.s1, weights_values0.s2,
weights_values1.s0, weights_values1.s1, weights_values1.s2,
weights_values2.s0, weights_values2.s1, weights_values2.s2);
#if defined(HAS_BIAS)
- pixels += (float2)(*((__global float *)(biases.ptr + get_global_id(2) * biases_stride_x)));
+ pixels += (float2)(*((__global float *)(biases.ptr + channel * biases_stride_x)));
#endif //defined(HAS_BIAS)
vstore2(pixels, 0, (__global float *)dst.ptr);
@@ -307,15 +312,19 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f32(
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
- Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
float2 pixels0 = 0.0f;
float2 pixels1 = 0.0f;
float2 pixels2 = 0.0f;
float2 pixels3 = 0.0f;
- __global uchar *weights_addr = (__global uchar *)weights.ptr;
- __global uchar *src_addr = src.ptr - (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Extract channel and linearized batch indices
+ const int channel = get_global_id(2) % DST_CHANNELS;
+ const int batch = get_global_id(2) / DST_CHANNELS;
+ // Load relevant input and weights data (accounts for the depth multiplier when indexing the input, OFM = IFM * DEPTH_MULTIPLIER)
+ __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
+ __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
// Load the weights
float3 weights_row0 = vload3(0, (__global float *)(weights_addr + 0 * weights_stride_y));
@@ -346,7 +355,7 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f32(
#ifdef HAS_BIAS
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
- float bias = *((__global float *)(vector_offset(&biases, get_global_id(2))));
+ float bias = *((__global float *)(vector_offset(&biases, channel)));
pixels0 += (float2)bias;
pixels1 += (float2)bias;
@@ -404,13 +413,17 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f32(
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
- Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
float2 pixels0 = 0.0f;
float2 pixels1 = 0.0f;
- __global uchar *weights_addr = (__global uchar *)weights.ptr;
- __global uchar *src_addr = src.ptr - (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Extract channel and linearized batch indices
+ const int channel = get_global_id(2) % DST_CHANNELS;
+ const int batch = get_global_id(2) / DST_CHANNELS;
+ // Load relevant input and weights data (accounts for the depth multiplier when indexing the input, OFM = IFM * DEPTH_MULTIPLIER)
+ __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
+ __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
// Load the weights
float3 weights_row0 = vload3(0, (__global float *)(weights_addr + 0 * weights_stride_y));
@@ -439,7 +452,7 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f32(
#ifdef HAS_BIAS
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
- float bias = *((__global float *)(vector_offset(&biases, get_global_id(2))));
+ float bias = *((__global float *)(vector_offset(&biases, channel)));
pixels0 += (float2)bias;
pixels1 += (float2)bias;
@@ -449,7 +462,7 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f32(
vstore2(pixels1, 0, (__global float *)(dst.ptr + 1 * dst_stride_y));
}
-#endif // defined(DEPTH_MULTIPLIER)
+#endif // defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS)
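
The pointer rewind applied to src.ptr in these kernels is index arithmetic on the collapsed Z dimension: for z = batch * DST_CHANNELS + channel it must land on input plane batch * IFM + channel / DEPTH_MULTIPLIER. A stand-alone sketch that checks this identity (plain C++; IFM and the loop bounds are illustrative values):

#include <cassert>

int main()
{
    const int IFM          = 3;       // input channels
    const int M            = 2;       // DEPTH_MULTIPLIER
    const int DST_CHANNELS = IFM * M; // OFM = IFM * DEPTH_MULTIPLIER
    const int BATCHES      = 2;

    for(int z = 0; z < BATCHES * DST_CHANNELS; ++z)
    {
        const int channel = z % DST_CHANNELS;
        const int batch   = z / DST_CHANNELS;
        // Rewind applied to src.ptr, in units of src_step_z:
        const int rewind = batch * (DST_CHANNELS / M) * (M - 1) + (channel - channel / M);
        // The plane actually read must be the matching input plane:
        assert(z - rewind == batch * IFM + channel / M);
    }
    return 0;
}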
#if defined(NCHW)
#define in_stride_x src_stride_x
@@ -617,7 +630,7 @@ __kernel void depthwise_vector_to_tensor(
#endif //defined(CONV_WIDTH) && defined(CONV_HEIGHT) && defined(DATA_TYPE)
-#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(DEPTH_MULTIPLIER)
+#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS)
#if defined(CONV_STRIDE_X)
#if CONV_STRIDE_X == 1
#define convolution1x3_f16 convolution1x3_stride_1_f16
@@ -781,23 +794,28 @@ __kernel void depthwise_convolution_3x3_f16(
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
- Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
#if defined(HAS_BIAS)
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
#endif //defined(HAS_BIAS)
- src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Extract channel and linearized batch indices
+ const int channel = get_global_id(2) % DST_CHANNELS;
+ const int batch = get_global_id(2) / DST_CHANNELS;
+ // Load relevant input and weights data (accounts for the depth multiplier when indexing the input, OFM = IFM * DEPTH_MULTIPLIER)
+ src.ptr -= batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z + (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
+ __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
uchar3 offset = (uchar3)(0, 1, 2) * (uchar3)weights_stride_y;
- half3 weights_values0 = vload3(0, (__global half *)(weights.ptr + offset.s0));
- half3 weights_values1 = vload3(0, (__global half *)(weights.ptr + offset.s1));
- half3 weights_values2 = vload3(0, (__global half *)(weights.ptr + offset.s2));
+ half3 weights_values0 = vload3(0, (__global half *)(weights_addr + offset.s0));
+ half3 weights_values1 = vload3(0, (__global half *)(weights_addr + offset.s1));
+ half3 weights_values2 = vload3(0, (__global half *)(weights_addr + offset.s2));
half4 pixels = convolution3x3_f16(&src, weights_values0.s0, weights_values0.s1, weights_values0.s2,
weights_values1.s0, weights_values1.s1, weights_values1.s2,
weights_values2.s0, weights_values2.s1, weights_values2.s2);
#if defined(HAS_BIAS)
- pixels += (half4)(*((__global half *)(biases.ptr + get_global_id(2) * biases_stride_x)));
+ pixels += (half4)(*((__global half *)(biases.ptr + channel * biases_stride_x)));
#endif //defined(HAS_BIAS)
vstore4(pixels, 0, (__global half *)dst.ptr);
@@ -849,12 +867,16 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f16(
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
- Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
+
+ // Extract channel and linearized batch indices
+ const int channel = get_global_id(2) % DST_CHANNELS;
+ const int batch = get_global_id(2) / DST_CHANNELS;
#ifdef HAS_BIAS
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
- half bias = *((__global half *)(vector_offset(&biases, get_global_id(2))));
+ half bias = *((__global half *)(vector_offset(&biases, channel)));
#endif /* defined(HAS_BIAS) */
half4 pixels0 = 0.0f;
@@ -862,8 +884,9 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f16(
half4 pixels2 = 0.0f;
half4 pixels3 = 0.0f;
- __global uchar *weights_addr = (__global uchar *)weights.ptr;
- __global uchar *src_addr = (__global uchar *)offset(&src, 0, 0) - (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Load relevant input and weights data (accounts for the depth multiplier when indexing the input, OFM = IFM * DEPTH_MULTIPLIER)
+ __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
+ __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
// Load the weights
half3 weights_row0 = vload3(0, (__global half *)(weights_addr + 0 * weights_stride_y));
@@ -948,19 +971,24 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f16(
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
- Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
+
+ // Extract channel and linearized batch indices
+ const int channel = get_global_id(2) % DST_CHANNELS;
+ const int batch = get_global_id(2) / DST_CHANNELS;
#ifdef HAS_BIAS
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
- half bias = *((__global half *)(vector_offset(&biases, get_global_id(2))));
+ half bias = *((__global half *)(vector_offset(&biases, channel)));
#endif /* defined(HAS_BIAS) */
half4 pixels0 = 0.0f;
half4 pixels1 = 0.0f;
- __global uchar *weights_addr = (__global uchar *)weights.ptr;
- __global uchar *src_addr = (__global uchar *)offset(&src, 0, 0) - (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Load relevant input and weights data (accounts for the depth multiplier when indexing the input, OFM = IFM * DEPTH_MULTIPLIER)
+ __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
+ __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
// Load the weights
half3 weights_row0 = vload3(0, (__global half *)(weights_addr + 0 * weights_stride_y));
@@ -994,7 +1022,7 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f16(
vstore4(pixels0, 0, (__global half *)(dst.ptr + 0 * dst_stride_y));
vstore4(pixels1, 0, (__global half *)(dst.ptr + 1 * dst_stride_y));
}
-#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(DEPTH_MULTIPLIER)
+#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS)
#if defined(VEC_SIZE) && defined(SRC_DIM_2) && defined(CONV_PAD_TOP) && defined(CONV_PAD_LEFT) && defined(DATA_TYPE)
diff --git a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
index 71889830c5..b3edc52612 100644
--- a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
+++ b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
@@ -45,7 +45,7 @@
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
-#if defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER)
+#if defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS)
#if CONV_STRIDE_X > 3
#error "Stride X not supported"
@@ -129,18 +129,25 @@ __kernel void depthwise_convolution_3x3_quantized_nchw(
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
- Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
+
+ // Extract channel and linearized batch indices
+ const int channel = get_global_id(2) % DST_CHANNELS;
+ const int batch = get_global_id(2) / DST_CHANNELS;
+
#if defined(HAS_BIAS)
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
- int bias_value = *((__global int *)(vector_offset(&biases, get_global_id(2))));
+ int bias_value = *((__global int *)(vector_offset(&biases, channel)));
#endif //defined(HAS_BIAS)
- src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Load relevant input and weights data (accounts for the depth multiplier when indexing the input, OFM = IFM * DEPTH_MULTIPLIER)
+ src.ptr -= batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z + (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
+ __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
- uchar3 w0 = vload3(0, weights.ptr + 0 * weights_stride_y);
- uchar3 w1 = vload3(0, weights.ptr + 1 * weights_stride_y);
- uchar3 w2 = vload3(0, weights.ptr + 2 * weights_stride_y);
+ uchar3 w0 = vload3(0, weights_addr + 0 * weights_stride_y);
+ uchar3 w1 = vload3(0, weights_addr + 1 * weights_stride_y);
+ uchar3 w2 = vload3(0, weights_addr + 2 * weights_stride_y);
int8 values0 = 0;
int8 sum0 = 0;
@@ -337,18 +344,25 @@ __kernel void depthwise_convolution_3x3_quantized_dot8_nchw(
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
- Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
+
+ // Extract channel and linearized batch indices
+ const int channel = get_global_id(2) % DST_CHANNELS;
+ const int batch = get_global_id(2) / DST_CHANNELS;
+
#if defined(HAS_BIAS)
- Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
+ Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
- const int bias_value = *((__global int *)(vector_offset(&biases, get_global_id(2))));
+ const int bias_value = *((__global int *)(vector_offset(&biases, channel)));
#endif //defined(HAS_BIAS)
- src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Load relevant input and weights data (accounts for the depth multiplier when indexing the input, OFM = IFM * DEPTH_MULTIPLIER)
+ src.ptr -= batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z + (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
+ __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
- uchar3 w0 = vload3(0, weights.ptr + 0 * weights_stride_y);
- uchar3 w1 = vload3(0, weights.ptr + 1 * weights_stride_y);
- uchar3 w2 = vload3(0, weights.ptr + 2 * weights_stride_y);
+ uchar3 w0 = vload3(0, weights_addr + 0 * weights_stride_y);
+ uchar3 w1 = vload3(0, weights_addr + 1 * weights_stride_y);
+ uchar3 w2 = vload3(0, weights_addr + 2 * weights_stride_y);
uchar8 left0, middle0, right0;
uchar8 left1, middle1, right1;
@@ -501,7 +515,7 @@ __kernel void depthwise_convolution_3x3_quantized_dot8_nchw(
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
-#endif /* defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER) */
+#endif /* defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS) */
#if defined(VEC_SIZE) && defined(SRC_DIM_1) && defined(SRC_DIM_2) && defined(CONV_PAD_TOP) && defined(CONV_PAD_LEFT)
diff --git a/src/core/CL/kernels/CLCol2ImKernel.cpp b/src/core/CL/kernels/CLCol2ImKernel.cpp
index 74bbb9b4df..d748745999 100644
--- a/src/core/CL/kernels/CLCol2ImKernel.cpp
+++ b/src/core/CL/kernels/CLCol2ImKernel.cpp
@@ -106,7 +106,7 @@ void CLCol2ImKernel::configure(const ICLTensor *input, ICLTensor *output, const
build_opts.add_option("-DELEMENT_SIZE=" + support::cpp11::to_string(input->info()->element_size()));
build_opts.add_option("-DWIDTH_INPUT=" + support::cpp11::to_string(input->info()->dimension(0)));
build_opts.add_option("-DWIDTH_OUTPUT=" + support::cpp11::to_string(_convolved_dims.width));
- build_opts.add_option_if(num_groups > 1, "-DGROUPING");
+ build_opts.add_option("-DNUM_GROUPS=" + support::cpp11::to_string(num_groups));
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("col2im", build_opts.options()));
@@ -143,22 +143,26 @@ void CLCol2ImKernel::run(const Window &window, cl::CommandQueue &queue)
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(ICLKernel::window(), window);
+ bool is_collapsed = false;
+ bool is_collapsed_out = false;
+
Window out_window;
out_window.use_tensor_dimensions(_output->info()->tensor_shape());
- Window slice = window.first_slice_window_3D();
- Window slice_out = out_window.first_slice_window_3D();
+ Window collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ, &is_collapsed);
+ Window collapsed_out = out_window.collapse_if_possible(out_window, 3, &is_collapsed_out);
- unsigned int idx = 2 * num_arguments_per_3D_tensor();
- _kernel.setArg<cl_uint>(idx++, _output->info()->strides_in_bytes()[3]);
+ ARM_COMPUTE_ERROR_ON(is_collapsed != is_collapsed_out);
+ Window slice = collapsed.first_slice_window_3D();
+ Window slice_out = collapsed_out.first_slice_window_4D();
do
{
// Set inputs
unsigned int idx = 0;
add_3D_tensor_argument(idx, _input, slice);
- add_3D_tensor_argument(idx, _output, slice_out);
+ add_4D_tensor_argument(idx, _output, slice_out);
enqueue(queue, *this, slice, lws_hint());
}
- while(window.slide_window_slice_3D(slice) && out_window.slide_window_slice_3D(slice_out));
+ while(collapsed.slide_window_slice_3D(slice) && collapsed_out.slide_window_slice_4D(slice_out));
}
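
The host-side change mirrors the kernel: instead of sliding a 3D slice and passing the batch stride as an extra kernel argument, the Z and batch dimensions are collapsed and one loop enqueues every slice. A conceptual sketch of the difference (plain C++ stand-ins, not the arm_compute Window API):

#include <cstdio>

int main()
{
    const int Z = 3; // e.g. groups
    const int W = 2; // batches

    // Before: a 3D slice per Z value, batches handled via a manual stride argument.
    for(int w = 0; w < W; ++w)
    {
        for(int z = 0; z < Z; ++z)
        {
            std::printf("enqueue slice z=%d w=%d\n", z, w);
        }
    }

    // After collapsing: one Z' = Z * W dimension; the kernel recovers the pair
    // with z = zp % Z and w = zp / Z, exactly as the .cl changes above do.
    for(int zp = 0; zp < Z * W; ++zp)
    {
        std::printf("enqueue slice z=%d w=%d\n", zp % Z, zp / Z);
    }
    return 0;
}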
diff --git a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp
index a40aa2856c..de7e2b8737 100644
--- a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp
+++ b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp
@@ -225,8 +225,17 @@ void CLDepthwiseConvolutionLayer3x3NCHWKernel::configure(const ICLTensor *input,
_conv_pad_top = conv_info.pad_top();
_border_size = BorderSize(_conv_pad_top, conv_info.pad_right(), conv_info.pad_bottom(), _conv_pad_left);
+ // Configure kernel window
+ std::string kernel_name;
+ const GPUTarget gpu_target = get_target();
+
+ auto win_config = validate_and_configure_window(input->info(), weights->info(), output->info(), conv_info, depth_multiplier, gpu_target, kernel_name);
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+ ICLKernel::configure_internal(win_config.second);
+
// Set build options
CLBuildOptions build_opts;
+ build_opts.add_option("-DDST_CHANNELS=" + support::cpp11::to_string(_output->info()->tensor_shape().z()));
build_opts.add_option("-DDEPTH_MULTIPLIER=" + support::cpp11::to_string(depth_multiplier));
build_opts.add_option("-DCONV_STRIDE_X=" + support::cpp11::to_string(_conv_stride_x));
build_opts.add_option_if(_biases != nullptr, "-DHAS_BIAS");
@@ -273,15 +282,6 @@ void CLDepthwiseConvolutionLayer3x3NCHWKernel::configure(const ICLTensor *input,
}
}
}
-
- // Configure kernel window
- std::string kernel_name;
- const GPUTarget gpu_target = get_target();
-
- auto win_config = validate_and_configure_window(input->info(), weights->info(), output->info(), conv_info, depth_multiplier, gpu_target, kernel_name);
- ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure_internal(win_config.second);
-
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
// Set config_id for enabling LWS tuning
@@ -316,15 +316,17 @@ void CLDepthwiseConvolutionLayer3x3NCHWKernel::run(const Window &window, cl::Com
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
+ Window collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
+
// Create input window and adjust
- Window win_in = window;
- win_in.adjust(Window::DimX, -_conv_pad_left, true);
- win_in.adjust(Window::DimY, -_conv_pad_top, true);
- win_in.set_dimension_step(Window::DimX, window.x().step() * _conv_stride_x);
- win_in.set_dimension_step(Window::DimY, window.y().step() * _conv_stride_y);
-
- Window slice_in = win_in.first_slice_window_3D();
- Window slice_out = window.first_slice_window_3D();
+ Window collapsed_in = collapsed;
+ collapsed_in.adjust(Window::DimX, -_conv_pad_left, true);
+ collapsed_in.adjust(Window::DimY, -_conv_pad_top, true);
+ collapsed_in.set_dimension_step(Window::DimX, collapsed_in.x().step() * _conv_stride_x);
+ collapsed_in.set_dimension_step(Window::DimY, collapsed_in.y().step() * _conv_stride_y);
+
+ Window slice_in = collapsed_in.first_slice_window_3D();
+ Window slice_out = collapsed.first_slice_window_3D();
Window slice_weights = window.first_slice_window_3D();
slice_weights.set_dimension_step(Window::DimX, 0);
slice_weights.set_dimension_step(Window::DimY, 0);
@@ -347,5 +349,5 @@ void CLDepthwiseConvolutionLayer3x3NCHWKernel::run(const Window &window, cl::Com
enqueue(queue, *this, slice_out, lws_hint());
}
- while(window.slide_window_slice_3D(slice_out) && win_in.slide_window_slice_3D(slice_in));
+ while(collapsed.slide_window_slice_3D(slice_out) && collapsed_in.slide_window_slice_3D(slice_in));
}
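
For the collapsed input window, the adjust/set_dimension_step pair keeps the usual strided-convolution mapping: the output step at coordinate xo reads the input starting at xo * conv_stride_x - pad_left. A quick stand-alone check (plain C++; the sizes are illustrative):

#include <cstdio>

int main()
{
    const int conv_stride_x = 2;
    const int pad_left      = 1;
    const int out_w         = 4;
    const int step_x        = 2; // elements processed per work item along X

    for(int xo = 0; xo < out_w; xo += step_x)
    {
        // adjust(DimX, -pad_left) plus step scaled by conv_stride_x on the input window:
        const int xi = xo * conv_stride_x - pad_left;
        std::printf("output x=%d reads input starting at x=%d\n", xo, xi);
    }
    return 0;
}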
diff --git a/src/core/CL/kernels/CLFillBorderKernel.cpp b/src/core/CL/kernels/CLFillBorderKernel.cpp
index baf6bb6024..69206678d0 100644
--- a/src/core/CL/kernels/CLFillBorderKernel.cpp
+++ b/src/core/CL/kernels/CLFillBorderKernel.cpp
@@ -168,7 +168,8 @@ void CLFillBorderKernel::run(const Window &window, cl::CommandQueue &queue)
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(ICLKernel::window(), window);
- Window slice = window.first_slice_window_3D();
+ Window collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
+ Window slice = collapsed.first_slice_window_3D();
do
{
@@ -176,5 +177,5 @@ void CLFillBorderKernel::run(const Window &window, cl::CommandQueue &queue)
add_3D_tensor_argument(idx, _tensor, slice);
enqueue(queue, *this, slice, cl::NullRange);
}
- while(window.slide_window_slice_3D(slice));
+ while(collapsed.slide_window_slice_3D(slice));
}
diff --git a/tests/datasets/Col2ImLayerDataset.h b/tests/datasets/Col2ImLayerDataset.h
index 96a3cab134..b39cedbde6 100644
--- a/tests/datasets/Col2ImLayerDataset.h
+++ b/tests/datasets/Col2ImLayerDataset.h
@@ -128,7 +128,7 @@ public:
add_config(TensorShape(8U, 16U, 3U, 1U), 4U, 4U, 3U);
add_config(TensorShape(8U, 16U, 3U, 3U), 4U, 4U, 3U);
add_config(TensorShape(12U, 20U, 4U, 1U), 5U, 4U, 4U);
- add_config(TensorShape(12U, 20U, 4U, 3U), 5U, 4U, 4U);
+ add_config(TensorShape(12U, 20U, 4U, 3U, 2U), 5U, 4U, 4U);
}
};
@@ -142,7 +142,7 @@ public:
add_config(TensorShape(333U, 280U, 1U, 77U), 14U, 20U, 1U);
add_config(TensorShape(333U, 280U, 77U, 1U), 14U, 20U, 1U);
add_config(TensorShape(120U, 300U, 8U, 3U), 20U, 15U, 8U);
- add_config(TensorShape(233U, 300U, 8U, 3U), 20U, 15U, 8U);
+ add_config(TensorShape(233U, 300U, 8U, 3U, 2U), 20U, 15U, 8U);
add_config(TensorShape(333U, 280U, 12U, 5U), 20U, 14U, 12U);
add_config(TensorShape(177U, 300U, 12U, 5U), 15U, 20U, 12U);
add_config(TensorShape(450U, 400U, 16U, 5U), 20U, 20U, 16U);
diff --git a/tests/validation/reference/Col2Im.cpp b/tests/validation/reference/Col2Im.cpp
index 90e488f928..53969d4725 100644
--- a/tests/validation/reference/Col2Im.cpp
+++ b/tests/validation/reference/Col2Im.cpp
@@ -40,7 +40,7 @@ SimpleTensor<T> col2im(const SimpleTensor<T> &src, const TensorShape &dst_shape,
SimpleTensor<T> dst{ dst_shape, src.data_type(), 1 };
// Compute reference
- const size_t batches = dst_shape[3];
+ const size_t batches = dst_shape.total_size() / (dst_shape.x() * dst_shape.y() * dst_shape.z());
const size_t src_width = src.shape().x();
const size_t src_height = src.shape().y();
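
Deriving batches from total_size() rather than dst_shape[3] lets the reference accept the new 5D dataset entries: every dimension above Z is folded into the batch count. A minimal sketch of that fold (plain C++ with a std::vector stand-in for TensorShape; the shape values are illustrative):

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    // e.g. a 5D destination shape [W, H, C, batch0, batch1]:
    const std::vector<std::size_t> dst_shape{20, 15, 8, 3, 2};

    std::size_t total = 1;
    for(std::size_t d : dst_shape)
    {
        total *= d;
    }
    const std::size_t batches = total / (dst_shape[0] * dst_shape[1] * dst_shape[2]);
    assert(batches == 3 * 2); // all dimensions above Z collapse into batches
    return 0;
}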