aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2017-08-14 11:26:37 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commitdef665a1a2e92baa1cfb192b65425b91ff6046b3 (patch)
tree03271d3e78190d81709618f40191d105b7c94917 /src/core/CL/cl_kernels
parentfc2817dc0436ef2d5064df0a061aafd3d324d894 (diff)
downloadComputeLibrary-def665a1a2e92baa1cfb192b65425b91ff6046b3.tar.gz
COMPMID-474 - Add support for QS8/QS16 DirectConvolution CL
Change-Id: I537e4acbc02c8d880ff8630ea62223e0f1a1dda3 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/82875 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels')
-rw-r--r--src/core/CL/cl_kernels/direct_convolution1x1.cl25
-rw-r--r--src/core/CL/cl_kernels/direct_convolution3x3.cl65
-rw-r--r--src/core/CL/cl_kernels/fixed_point.h10
3 files changed, 71 insertions, 29 deletions
diff --git a/src/core/CL/cl_kernels/direct_convolution1x1.cl b/src/core/CL/cl_kernels/direct_convolution1x1.cl
index ec0551b018..66c618e033 100644
--- a/src/core/CL/cl_kernels/direct_convolution1x1.cl
+++ b/src/core/CL/cl_kernels/direct_convolution1x1.cl
@@ -23,6 +23,23 @@
*/
#include "helpers.h"
+#if defined(FIXED_POINT_POSITION)
+#include "fixed_point.h"
+
+#define ADD_OP(a, b) ADD_SAT_OP_EXPAND((a), (b), DATA_TYPE_PROMOTED, 8)
+#define MUL_OP(a, b) MUL_SAT_OP_EXPAND(CONVERT((a), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)), CONVERT((b), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)), DATA_TYPE_PROMOTED, 8, FIXED_POINT_POSITION)
+
+// There is no need to have a larger intermediate type for qs32 because all the arguments are already promoted
+MULQ_SAT_IMPL(qs32x8, qs32x8)
+
+#else /* FIXED_POINT_POSITION */
+
+#define ADD_OP(a, b) ((a) + (b))
+#define MUL_OP(a, b) ((a) * (b))
+#define CONVERT_SAT(a, b) ((a))
+
+#endif /* FIXED_POINT_POSITION */
+
#if STRIDE_X == 3
#define INPUT_PIXEL_STR(data_size) extract_input_stride3_##data_size
#define INPUT_PIXEL(data_size) INPUT_PIXEL_STR(data_size)
@@ -165,7 +182,7 @@ __kernel void direct_convolution1x1(
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
#endif /* defined(HAS_BIAS) */
- VEC_DATA_TYPE(DATA_TYPE, 8)
+ VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)
pixels = 0;
const uint z_index = get_global_id(2);
@@ -177,15 +194,15 @@ __kernel void direct_convolution1x1(
DATA_TYPE weight = *(__global DATA_TYPE *)weights.ptr;
VEC_DATA_TYPE(DATA_TYPE, 8)
input_pixel = INPUT_PIXEL(DATA_SIZE)((__global DATA_TYPE *)src.ptr);
- pixels += weight * input_pixel;
+ pixels = ADD_OP(pixels, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))weight, input_pixel));
src.ptr += src_stride_z;
weights.ptr += weights_stride_z;
}
#ifdef HAS_BIAS
- pixels += (VEC_DATA_TYPE(DATA_TYPE, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, z_index)));
+ pixels = ADD_OP(pixels, (VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, z_index))));
#endif /* defined(HAS_BIAS) */
- vstore8(pixels, 0, (__global DATA_TYPE *)dst.ptr);
+ vstore8(CONVERT_SAT(pixels, VEC_DATA_TYPE(DATA_TYPE, 8)), 0, (__global DATA_TYPE *)dst.ptr);
}
#endif // defined(DATA_TYPE) && defined(DATA_SIZE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH) \ No newline at end of file
diff --git a/src/core/CL/cl_kernels/direct_convolution3x3.cl b/src/core/CL/cl_kernels/direct_convolution3x3.cl
index 51886efe64..4da7c39e26 100644
--- a/src/core/CL/cl_kernels/direct_convolution3x3.cl
+++ b/src/core/CL/cl_kernels/direct_convolution3x3.cl
@@ -23,6 +23,23 @@
*/
#include "helpers.h"
+#if defined(FIXED_POINT_POSITION)
+#include "fixed_point.h"
+
+#define ADD_OP(a, b) ADD_SAT_OP_EXPAND((a), (b), DATA_TYPE_PROMOTED, 8)
+#define MUL_OP(a, b) MUL_SAT_OP_EXPAND(CONVERT((a), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)), CONVERT((b), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)), DATA_TYPE_PROMOTED, 8, FIXED_POINT_POSITION)
+
+// There is no need to have a larger intermediate type for qs32 because all the arguments are already promoted
+MULQ_SAT_IMPL(qs32x8, qs32x8)
+
+#else /* FIXED_POINT_POSITION */
+
+#define ADD_OP(a, b) ((a) + (b))
+#define MUL_OP(a, b) ((a) * (b))
+#define CONVERT_SAT(a, b) ((a))
+
+#endif /* FIXED_POINT_POSITION */
+
#if STRIDE_X == 1
#define CONVOLUTION1x3(acc, src_row_ptr, weights_row_ptr) CONVOLUTION1x3_STRIDE1(acc, src_row_ptr, weights_row_ptr)
#elif STRIDE_X == 2 /* STRIDE_X == 1 */
@@ -31,31 +48,31 @@
#error "STRIDE_X larger than 2 is not supported"
#endif /* STRIDE_X == 2 */
-#define CONVOLUTION1x3_STRIDE1(acc, src_row_ptr, weights_row_ptr) \
- ({ \
- VEC_DATA_TYPE(DATA_TYPE, 4) \
- weights_values0 = vload4(0, weights_row_ptr); \
- VEC_DATA_TYPE(DATA_TYPE, 8) \
- src0 = vload8(0, src_row_ptr); \
- VEC_DATA_TYPE(DATA_TYPE, 2) \
- src1 = vload2(0, src_row_ptr + 8); \
+#define CONVOLUTION1x3_STRIDE1(acc, src_row_ptr, weights_row_ptr) \
+ ({ \
+ VEC_DATA_TYPE(DATA_TYPE, 4) \
+ weights_values0 = vload4(0, weights_row_ptr); \
+ VEC_DATA_TYPE(DATA_TYPE, 8) \
+ src0 = vload8(0, src_row_ptr); \
+ VEC_DATA_TYPE(DATA_TYPE, 2) \
+ src1 = vload2(0, src_row_ptr + 8); \
\
- acc += src0 * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s0; \
- acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1234, src0.s567, src1.s0) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s1; \
- acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s234, src0.s567, src1.s01) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s2; \
+ acc = ADD_OP(acc, MUL_OP(src0, (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s0)); \
+ acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1234, src0.s567, src1.s0), (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s1)); \
+ acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s234, src0.s567, src1.s01), (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s2)); \
})
-#define CONVOLUTION1x3_STRIDE2(acc, src_row_ptr, weights_row_ptr) \
- ({ \
- VEC_DATA_TYPE(DATA_TYPE, 4) \
- weights_values0 = vload4(0, weights_row_ptr); \
- VEC_DATA_TYPE(DATA_TYPE, 16) \
- src0 = vload16(0, src_row_ptr); \
- DATA_TYPE src1 = *(src_row_ptr + 16); \
+#define CONVOLUTION1x3_STRIDE2(acc, src_row_ptr, weights_row_ptr) \
+ ({ \
+ VEC_DATA_TYPE(DATA_TYPE, 4) \
+ weights_values0 = vload4(0, weights_row_ptr); \
+ VEC_DATA_TYPE(DATA_TYPE, 16) \
+ src0 = vload16(0, src_row_ptr); \
+ DATA_TYPE src1 = *(src_row_ptr + 16); \
\
- acc += src0.even * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s0; \
- acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1357, src0.s9BDF) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s1; \
- acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s2468, src0.sACE, src1) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s2; \
+ acc = ADD_OP(acc, MUL_OP(src0.even, (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s0)); \
+ acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1357, src0.s9BDF), (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s1)); \
+ acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s2468, src0.sACE, src1), (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s2)); \
})
/** This kernel performs a direct convolution to convolve the low three dimensions.
@@ -108,7 +125,7 @@ __kernel void direct_convolution3x3(
Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
- VEC_DATA_TYPE(DATA_TYPE, 8)
+ VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)
pixels0 = 0;
__global uchar *weights_addr = (__global uchar *)tensor3D_offset(&weights, 0, 0, 0);
@@ -130,9 +147,9 @@ __kernel void direct_convolution3x3(
#ifdef HAS_BIAS
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
- pixels0 += (VEC_DATA_TYPE(DATA_TYPE, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, kernel_index)));
+ pixels0 = ADD_OP(pixels0, (VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, kernel_index))));
#endif /* defined(HAS_BIAS) */
- vstore8(pixels0, 0, (__global DATA_TYPE *)dst.ptr);
+ vstore8(CONVERT_SAT(pixels0, VEC_DATA_TYPE(DATA_TYPE, 8)), 0, (__global DATA_TYPE *)dst.ptr);
}
#endif // defined(DATA_TYPE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH) \ No newline at end of file
diff --git a/src/core/CL/cl_kernels/fixed_point.h b/src/core/CL/cl_kernels/fixed_point.h
index 7038d40e16..d35a46f428 100644
--- a/src/core/CL/cl_kernels/fixed_point.h
+++ b/src/core/CL/cl_kernels/fixed_point.h
@@ -168,6 +168,11 @@ ADDQ_SAT_IMPL(qs16x2)
ADDQ_SAT_IMPL(qs16x4)
ADDQ_SAT_IMPL(qs16x8)
ADDQ_SAT_IMPL(qs16x16)
+ADDQ_SAT_IMPL(qs32x1)
+ADDQ_SAT_IMPL(qs32x2)
+ADDQ_SAT_IMPL(qs32x4)
+ADDQ_SAT_IMPL(qs32x8)
+ADDQ_SAT_IMPL(qs32x16)
#define ADD_SAT_OP_EXPAND_STR(a, b, type, size) add_sat_##type##x##size((a), (b))
#define ADD_SAT_OP_EXPAND(a, b, type, size) ADD_SAT_OP_EXPAND_STR(a, b, type, size)
@@ -213,6 +218,8 @@ SUBQ_SAT_IMPL(qs16x16)
return CONVERT((res >> (itype)fixed_point_position), type); \
}
+MULQ_IMPL(qs8x8, qs16x8)
+MULQ_IMPL(qs16x8, qs32x8)
MULQ_IMPL(qs8x16, qs16x16)
MULQ_IMPL(qs16x16, qs32x16)
@@ -234,8 +241,9 @@ MULQ_IMPL(qs16x16, qs32x16)
return CONVERT_SAT((res >> (itype)fixed_point_position), type); \
}
-MULQ_SAT_IMPL(qs8x16, qs16x16)
+MULQ_SAT_IMPL(qs8x8, qs16x8)
MULQ_SAT_IMPL(qs16x8, qs32x8)
+MULQ_SAT_IMPL(qs8x16, qs16x16)
MULQ_SAT_IMPL(qs16x16, qs32x16)
#define MUL_SAT_OP_EXPAND_STR(a, b, type, size, position) mul_sat_##type##x##size((a), (b), (position))