author    Georgios Pinitas <georgios.pinitas@arm.com>  2018-09-13 17:20:04 +0100
committer Anthony Barbier <anthony.barbier@arm.com>    2018-11-02 16:54:54 +0000
commit    e55b40a4d0cc5a82b8f0fd9ffec203ded9f3c63d (patch)
tree      e7736258428837e3889108909d58592937fe71fd /src/core/CL/cl_kernels
parent    64f1a908841913049ccc0eb941b5b213290d7bf7 (diff)
download  ComputeLibrary-e55b40a4d0cc5a82b8f0fd9ffec203ded9f3c63d.tar.gz
COMPMID-1581: Collapse windows
Change-Id: Iec56c9a96d9736a63f13b65efa33311950f20661
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/148572
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: bsgcomp <bsgcomp@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels')
-rw-r--r-- src/core/CL/cl_kernels/col2im.cl                          45
-rw-r--r-- src/core/CL/cl_kernels/depthwise_convolution.cl           92
-rw-r--r-- src/core/CL/cl_kernels/depthwise_convolution_quantized.cl 44
3 files changed, 111 insertions, 70 deletions
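The pattern shared by all three kernels below: once the Z window is collapsed, get_global_id(2) linearizes the batch index together with the group (col2im) or the output channel (depthwise convolution), and each kernel decomposes it again with one division and one modulo. A minimal sketch of that decomposition, assuming the host side launches gws[2] = batches * DST_CHANNELS (depthwise) or batches * NUM_GROUPS (col2im):

    // Sketch only, not part of the patch; DST_CHANNELS and NUM_GROUPS are
    // the compile-time defines this commit introduces.
    // Depthwise kernels: gws[2] = batches * DST_CHANNELS
    const int channel = get_global_id(2) % DST_CHANNELS; // output channel
    const int batch   = get_global_id(2) / DST_CHANNELS; // linearized batch
    // col2im with NUM_GROUPS > 1: gws[2] = batches * NUM_GROUPS
    const uint group  = get_global_id(2) % NUM_GROUPS;   // group within batch
    const uint batchg = get_global_id(2) / NUM_GROUPS;   // linearized batch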
diff --git a/src/core/CL/cl_kernels/col2im.cl b/src/core/CL/cl_kernels/col2im.cl
index 5e52127f27..b02d07b332 100644
--- a/src/core/CL/cl_kernels/col2im.cl
+++ b/src/core/CL/cl_kernels/col2im.cl
@@ -23,7 +23,7 @@
*/
#include "helpers.h"
-#if defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT)
+#if defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT) && defined(NUM_GROUPS)
#if ELEMENT_SIZE == 1
#define COND_DATA_TYPE char
@@ -41,7 +41,7 @@
* @note The width of the input tensor must be passed at compile time using -DWIDTH_INPUT: e.g. -DWIDTH_INPUT=320
* @note The width of the output tensor must be passed at compile time using -DWIDTH_OUTPUT: e.g. -DWIDTH_OUTPUT=600
* @note The element size must be passed at compile time using -DELEMENT_SIZE: e.g. -DELEMENT_SIZE=4
- * @note In case of grouping the GROUPING flag must be passed at compile time using -DGROUPING
+ * @note The number of groups must be passed at compile time using -DNUM_GROUPS: e.g. -DNUM_GROUPS=4
*
* @param[in] src_ptr Pointer to the source tensor. Supported data types: QASYMM8/F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
@@ -58,15 +58,16 @@
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
+ * @param[in] dst_step_w dst_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
*/
__kernel void col2im(
TENSOR3D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst),
- uint dst_stride_w)
+ TENSOR4D_DECLARATION(dst))
{
Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+ Tensor4D dst = CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(dst, 0);
const uint xd = get_global_id(1) % WIDTH_OUTPUT; // x coordinate of the destination tensor
const uint yd = get_global_id(1) / WIDTH_OUTPUT; // y coordinate of the destination tensor
@@ -86,27 +87,25 @@ __kernel void col2im(
// If out-of-bound, overwrite with the first element
data = select((VEC_DATA_TYPE(DATA_TYPE, 8))data.s0, data, cond0);
- __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes;
-
-#if defined(GROUPING)
- // Compute output offset (batches on 4th dimension, no need to compute manually)
- int idx = yd * dst_stride_y + xd * dst_stride_x;
+#if NUM_GROUPS > 1
+ // Compute output offset (batches on 4th dimension)
+ int idx = yd * dst_stride_y + xd * dst_stride_x + (get_global_id(2) / NUM_GROUPS) * dst.stride_w;
- const uint group = get_global_id(2); // group ID
+ const uint group = get_global_id(2) % NUM_GROUPS; // group ID
x_clamped += group * WIDTH_INPUT;
-#else /* defined(GROUPING) */
+#else /* NUM_GROUPS > 1 */
// Compute output offset (batches on 3rd dimension)
- int idx = yd * dst_stride_y + xd * dst_stride_x + get_global_id(2) * dst_stride_w;
-#endif /* GROUPING */
+ int idx = yd * dst.stride_y + xd * dst.stride_x + get_global_id(2) * dst.stride_w;
+#endif /* NUM_GROUPS > 1 */
// Store value
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s0 * dst_stride_z)) = data.s0;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s1 * dst_stride_z)) = data.s1;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s2 * dst_stride_z)) = data.s2;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s3 * dst_stride_z)) = data.s3;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s4 * dst_stride_z)) = data.s4;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s5 * dst_stride_z)) = data.s5;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s6 * dst_stride_z)) = data.s6;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s7 * dst_stride_z)) = data.s7;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s0 * dst.stride_z)) = data.s0;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s1 * dst.stride_z)) = data.s1;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s2 * dst.stride_z)) = data.s2;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s3 * dst.stride_z)) = data.s3;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s4 * dst.stride_z)) = data.s4;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s5 * dst.stride_z)) = data.s5;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s6 * dst.stride_z)) = data.s6;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s7 * dst.stride_z)) = data.s7;
}
-#endif // defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT)
+#endif // defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT) && defined(NUM_GROUPS)
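A worked check of the NUM_GROUPS > 1 indexing above, with assumed values (not from the patch):

    // Assume NUM_GROUPS = 4 and 2 batches, so gws[2] = 8.
    // For get_global_id(2) == 6:
    //   batch = 6 / NUM_GROUPS = 1 -> adds 1 * dst.stride_w to idx
    //   group = 6 % NUM_GROUPS = 2 -> shifts x_clamped by 2 * WIDTH_INPUT
    // Each lane data.sN is then stored at
    //   dst.ptr + idx + x_clamped.sN * dst.stride_z
    // so the batch selects the W slice, (xd, yd) the spatial position, and
    // the clamped column index the output channel plane.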
diff --git a/src/core/CL/cl_kernels/depthwise_convolution.cl b/src/core/CL/cl_kernels/depthwise_convolution.cl
index 23237da562..97b46c47cf 100644
--- a/src/core/CL/cl_kernels/depthwise_convolution.cl
+++ b/src/core/CL/cl_kernels/depthwise_convolution.cl
@@ -24,7 +24,7 @@
#include "helpers.h"
-#if defined(DEPTH_MULTIPLIER)
+#if defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS)
#if defined(CONV_STRIDE_X)
#if CONV_STRIDE_X == 1
@@ -188,23 +188,28 @@ __kernel void depthwise_convolution_3x3(
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
- Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
#if defined(HAS_BIAS)
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
#endif //defined(HAS_BIAS)
- src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Extract channel and linearized batch indices
+ const int channel = get_global_id(2) % DST_CHANNELS;
+ const int batch = get_global_id(2) / DST_CHANNELS;
+ // Load relevant input and weights data (accounts for the depth multiplier when indexing the input; OFM = IFM * DEPTH_MULTIPLIER)
+ src.ptr -= batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z + (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
+ __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
uchar3 offset = (uchar3)(0, 1, 2) * (uchar3)weights_stride_y;
- float3 weights_values0 = vload3(0, (__global float *)(weights.ptr + offset.s0));
- float3 weights_values1 = vload3(0, (__global float *)(weights.ptr + offset.s1));
- float3 weights_values2 = vload3(0, (__global float *)(weights.ptr + offset.s2));
+ float3 weights_values0 = vload3(0, (__global float *)(weights_addr + offset.s0));
+ float3 weights_values1 = vload3(0, (__global float *)(weights_addr + offset.s1));
+ float3 weights_values2 = vload3(0, (__global float *)(weights_addr + offset.s2));
float2 pixels = convolution3x3(&src, weights_values0.s0, weights_values0.s1, weights_values0.s2,
weights_values1.s0, weights_values1.s1, weights_values1.s2,
weights_values2.s0, weights_values2.s1, weights_values2.s2);
#if defined(HAS_BIAS)
- pixels += (float2)(*((__global float *)(biases.ptr + get_global_id(2) * biases_stride_x)));
+ pixels += (float2)(*((__global float *)(biases.ptr + channel * biases_stride_x)));
#endif //defined(HAS_BIAS)
vstore2(pixels, 0, (__global float *)dst.ptr);
@@ -307,15 +312,19 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f32(
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
- Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
float2 pixels0 = 0.0f;
float2 pixels1 = 0.0f;
float2 pixels2 = 0.0f;
float2 pixels3 = 0.0f;
- __global uchar *weights_addr = (__global uchar *)weights.ptr;
- __global uchar *src_addr = src.ptr - (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Extract channel and linearized batch indices
+ const int channel = get_global_id(2) % DST_CHANNELS;
+ const int batch = get_global_id(2) / DST_CHANNELS;
+ // Load relevant input and weights data (accounts for the depth multiplier when indexing the input; OFM = IFM * DEPTH_MULTIPLIER)
+ __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
+ __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
// Load the weights
float3 weights_row0 = vload3(0, (__global float *)(weights_addr + 0 * weights_stride_y));
@@ -346,7 +355,7 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f32(
#ifdef HAS_BIAS
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
- float bias = *((__global float *)(vector_offset(&biases, get_global_id(2))));
+ float bias = *((__global float *)(vector_offset(&biases, channel)));
pixels0 += (float2)bias;
pixels1 += (float2)bias;
@@ -404,13 +413,17 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f32(
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
- Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
float2 pixels0 = 0.0f;
float2 pixels1 = 0.0f;
- __global uchar *weights_addr = (__global uchar *)weights.ptr;
- __global uchar *src_addr = src.ptr - (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Extract channel and linearized batch indices
+ const int channel = get_global_id(2) % DST_CHANNELS;
+ const int batch = get_global_id(2) / DST_CHANNELS;
+ // Load relevant input and weights data (accounts for the depth multiplier when indexing the input; OFM = IFM * DEPTH_MULTIPLIER)
+ __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
+ __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
// Load the weights
float3 weights_row0 = vload3(0, (__global float *)(weights_addr + 0 * weights_stride_y));
@@ -439,7 +452,7 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f32(
#ifdef HAS_BIAS
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
- float bias = *((__global float *)(vector_offset(&biases, get_global_id(2))));
+ float bias = *((__global float *)(vector_offset(&biases, channel)));
pixels0 += (float2)bias;
pixels1 += (float2)bias;
@@ -449,7 +462,7 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f32(
vstore2(pixels1, 0, (__global float *)(dst.ptr + 1 * dst_stride_y));
}
-#endif // defined(DEPTH_MULTIPLIER)
+#endif // defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS)
#if defined(NCHW)
#define in_stride_x src_stride_x
@@ -617,7 +630,7 @@ __kernel void depthwise_vector_to_tensor(
#endif //defined(CONV_WIDTH) && defined(CONV_HEIGHT) && defined(DATA_TYPE)
-#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(DEPTH_MULTIPLIER)
+#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS)
#if defined(CONV_STRIDE_X)
#if CONV_STRIDE_X == 1
#define convolution1x3_f16 convolution1x3_stride_1_f16
@@ -781,23 +794,28 @@ __kernel void depthwise_convolution_3x3_f16(
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
- Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
#if defined(HAS_BIAS)
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
#endif //defined(HAS_BIAS)
- src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Extract channel and linearized batch indices
+ const int channel = get_global_id(2) % DST_CHANNELS;
+ const int batch = get_global_id(2) / DST_CHANNELS;
+ // Load relevant input and weights data (accounts for the depth multiplier when indexing the input; OFM = IFM * DEPTH_MULTIPLIER)
+ src.ptr -= batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z + (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
+ __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
uchar3 offset = (uchar3)(0, 1, 2) * (uchar3)weights_stride_y;
- half3 weights_values0 = vload3(0, (__global half *)(weights.ptr + offset.s0));
- half3 weights_values1 = vload3(0, (__global half *)(weights.ptr + offset.s1));
- half3 weights_values2 = vload3(0, (__global half *)(weights.ptr + offset.s2));
+ half3 weights_values0 = vload3(0, (__global half *)(weights_addr + offset.s0));
+ half3 weights_values1 = vload3(0, (__global half *)(weights_addr + offset.s1));
+ half3 weights_values2 = vload3(0, (__global half *)(weights_addr + offset.s2));
half4 pixels = convolution3x3_f16(&src, weights_values0.s0, weights_values0.s1, weights_values0.s2,
weights_values1.s0, weights_values1.s1, weights_values1.s2,
weights_values2.s0, weights_values2.s1, weights_values2.s2);
#if defined(HAS_BIAS)
- pixels += (half4)(*((__global half *)(biases.ptr + get_global_id(2) * biases_stride_x)));
+ pixels += (half4)(*((__global half *)(biases.ptr + channel * biases_stride_x)));
#endif //defined(HAS_BIAS)
vstore4(pixels, 0, (__global half *)dst.ptr);
@@ -849,12 +867,16 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f16(
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
- Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
+
+ // Extract channel and linearized batch indices
+ const int channel = get_global_id(2) % DST_CHANNELS;
+ const int batch = get_global_id(2) / DST_CHANNELS;
#ifdef HAS_BIAS
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
- half bias = *((__global half *)(vector_offset(&biases, get_global_id(2))));
+ half bias = *((__global half *)(vector_offset(&biases, channel)));
#endif /* defined(HAS_BIAS) */
half4 pixels0 = 0.0f;
@@ -862,8 +884,9 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f16(
half4 pixels2 = 0.0f;
half4 pixels3 = 0.0f;
- __global uchar *weights_addr = (__global uchar *)weights.ptr;
- __global uchar *src_addr = (__global uchar *)offset(&src, 0, 0) - (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Load relevant input and weights data (accounts for the depth multiplier when indexing the input; OFM = IFM * DEPTH_MULTIPLIER)
+ __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
+ __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
// Load the weights
half3 weights_row0 = vload3(0, (__global half *)(weights_addr + 0 * weights_stride_y));
@@ -948,19 +971,24 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f16(
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
- Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
+
+ // Extract channel and linearized batch indices
+ const int channel = get_global_id(2) % DST_CHANNELS;
+ const int batch = get_global_id(2) / DST_CHANNELS;
#ifdef HAS_BIAS
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
- half bias = *((__global half *)(vector_offset(&biases, get_global_id(2))));
+ half bias = *((__global half *)(vector_offset(&biases, channel)));
#endif /* defined(HAS_BIAS) */
half4 pixels0 = 0.0f;
half4 pixels1 = 0.0f;
- __global uchar *weights_addr = (__global uchar *)weights.ptr;
- __global uchar *src_addr = (__global uchar *)offset(&src, 0, 0) - (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Load relevant input and weights data (accounts for the depth multiplier when indexing the input; OFM = IFM * DEPTH_MULTIPLIER)
+ __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
+ __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
// Load the weights
half3 weights_row0 = vload3(0, (__global half *)(weights_addr + 0 * weights_stride_y));
@@ -994,7 +1022,7 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f16(
vstore4(pixels0, 0, (__global half *)(dst.ptr + 0 * dst_stride_y));
vstore4(pixels1, 0, (__global half *)(dst.ptr + 1 * dst_stride_y));
}
-#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(DEPTH_MULTIPLIER)
+#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS)
#if defined(VEC_SIZE) && defined(SRC_DIM_2) && defined(CONV_PAD_TOP) && defined(CONV_PAD_LEFT) && defined(DATA_TYPE)
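Why the rewritten src offsets in these depthwise kernels are equivalent to the old single-batch form: CONVERT_TENSOR3D_TO_IMAGE_STRUCT has already advanced src.ptr by get_global_id(2) * src_step_z, while the input plane actually needed is batch * IFM + channel / DEPTH_MULTIPLIER, with IFM = DST_CHANNELS / DEPTH_MULTIPLIER; the subtracted expression is exactly the difference between the two. A worked check with assumed values:

    // Assume DEPTH_MULTIPLIER = 2, DST_CHANNELS = 6 (so IFM = 3),
    // batch = 1, channel = 5, hence get_global_id(2) = 1 * 6 + 5 = 11.
    // Plane wanted: batch * IFM + channel / DEPTH_MULTIPLIER = 3 + 2 = 5.
    // Rewind applied (in units of src_step_z):
    //   batch * IFM * (DEPTH_MULTIPLIER - 1) + (channel - channel / DEPTH_MULTIPLIER)
    // = 1 * 3 * 1 + (5 - 2) = 6 = 11 - 5, as required.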
diff --git a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
index 71889830c5..b3edc52612 100644
--- a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
+++ b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
@@ -45,7 +45,7 @@
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
-#if defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER)
+#if defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS)
#if CONV_STRIDE_X > 3
#error "Stride X not supported"
@@ -129,18 +129,25 @@ __kernel void depthwise_convolution_3x3_quantized_nchw(
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
- Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
+
+ // Extract channel and linearized batch indices
+ const int channel = get_global_id(2) % DST_CHANNELS;
+ const int batch = get_global_id(2) / DST_CHANNELS;
+
#if defined(HAS_BIAS)
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
- int bias_value = *((__global int *)(vector_offset(&biases, get_global_id(2))));
+ int bias_value = *((__global int *)(vector_offset(&biases, channel)));
#endif //defined(HAS_BIAS)
- src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Load relevant input and weights data (accounts for the depth multiplier when indexing the input; OFM = IFM * DEPTH_MULTIPLIER)
+ src.ptr -= batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z + (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
+ __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
- uchar3 w0 = vload3(0, weights.ptr + 0 * weights_stride_y);
- uchar3 w1 = vload3(0, weights.ptr + 1 * weights_stride_y);
- uchar3 w2 = vload3(0, weights.ptr + 2 * weights_stride_y);
+ uchar3 w0 = vload3(0, weights_addr + 0 * weights_stride_y);
+ uchar3 w1 = vload3(0, weights_addr + 1 * weights_stride_y);
+ uchar3 w2 = vload3(0, weights_addr + 2 * weights_stride_y);
int8 values0 = 0;
int8 sum0 = 0;
@@ -337,18 +344,25 @@ __kernel void depthwise_convolution_3x3_quantized_dot8_nchw(
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
- Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
+
+ // Extract channel and linearized batch indices
+ const int channel = get_global_id(2) % DST_CHANNELS;
+ const int batch = get_global_id(2) / DST_CHANNELS;
+
#if defined(HAS_BIAS)
- Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
+ Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
- const int bias_value = *((__global int *)(vector_offset(&biases, get_global_id(2))));
+ const int bias_value = *((__global int *)(vector_offset(&biases, channel)));
#endif //defined(HAS_BIAS)
- src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Load relevant input and weights data (accounts for the depth multiplier when indexing the input; OFM = IFM * DEPTH_MULTIPLIER)
+ src.ptr -= batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z + (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
+ __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
- uchar3 w0 = vload3(0, weights.ptr + 0 * weights_stride_y);
- uchar3 w1 = vload3(0, weights.ptr + 1 * weights_stride_y);
- uchar3 w2 = vload3(0, weights.ptr + 2 * weights_stride_y);
+ uchar3 w0 = vload3(0, weights_addr + 0 * weights_stride_y);
+ uchar3 w1 = vload3(0, weights_addr + 1 * weights_stride_y);
+ uchar3 w2 = vload3(0, weights_addr + 2 * weights_stride_y);
uchar8 left0, middle0, right0;
uchar8 left1, middle1, right1;
@@ -501,7 +515,7 @@ __kernel void depthwise_convolution_3x3_quantized_dot8_nchw(
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
-#endif /* defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER) */
+#endif /* defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS) */
#if defined(VEC_SIZE) && defined(SRC_DIM_1) && defined(SRC_DIM_2) && defined(CONV_PAD_TOP) && defined(CONV_PAD_LEFT)
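For reference, the compile-time defines the guarded kernels above now require, gathered from their #if guards, as a plain-C host-side illustration. The col2im values echo the @note examples in the kernel documentation; the depthwise values are assumptions chosen for illustration, and ComputeLibrary's actual option-building helpers are not shown:

    /* Sketch of the -D options a host would pass when compiling the kernels. */
    const char *col2im_build_opts =
        "-DDATA_TYPE=float -DELEMENT_SIZE=4 "
        "-DWIDTH_INPUT=320 -DWIDTH_OUTPUT=600 -DNUM_GROUPS=4";
    const char *depthwise_build_opts = /* assumed example values */
        "-DDEPTH_MULTIPLIER=2 -DDST_CHANNELS=64 "
        "-DCONV_STRIDE_X=1 -DCONV_STRIDE_Y=1";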