From aa95ddc2abb7cef0b2edd03f7c4c9d9c6b9d7cf4 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Tue, 21 Jul 2020 22:45:13 +0100 Subject: COMPMID-3535: 9x9 Direct convolution support for CL and NHWC * Supported strides 1 and 2 Signed-off-by: Georgios Pinitas Change-Id: I4b9f087c0c328234159b2d1eacc2e465b3bb3c54 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3603 Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- .../CL/kernels/CLDirectConvolutionLayerKernel.h | 4 +- .../CL/cl_kernels/direct_convolution_quantized.cl | 140 ++++++++++++++++++++- .../CL/kernels/CLDirectConvolutionLayerKernel.cpp | 18 +-- tests/validation/CL/DirectConvolutionLayer.cpp | 2 +- 4 files changed, 145 insertions(+), 19 deletions(-) diff --git a/arm_compute/core/CL/kernels/CLDirectConvolutionLayerKernel.h b/arm_compute/core/CL/kernels/CLDirectConvolutionLayerKernel.h index b01ce6c0e8..5281a0c306 100644 --- a/arm_compute/core/CL/kernels/CLDirectConvolutionLayerKernel.h +++ b/arm_compute/core/CL/kernels/CLDirectConvolutionLayerKernel.h @@ -54,7 +54,7 @@ public: * 1x1 convolution with stride_x = 1/2/3, stride_y = 1/2/3 * 3x3 convolution with stride_x = 1/2, stride_y = 1/2 * 5x5 convolution with stride_x = 1/2, stride_y = 1/2 - * 9x9 convolution with stride_x = 1/2, stride_y = 1/2, data_layout=NHWC + * 9x9 convolution with stride_x = 1/2, stride_y = 1/2 * * @param[in] input The input tensor to convolve. 3 lower dimensions represent a single input [width, height, IFM], * while every optional dimension from 4 and above represent a batch of inputs. Data types supported: QASYMM8_SIGNED/QASYMM8/F16/F32. @@ -74,7 +74,7 @@ public: * 1x1 convolution with stride_x = 1/2/3, stride_y = 1/2/3 * 3x3 convolution with stride_x = 1/2, stride_y = 1/2 * 5x5 convolution with stride_x = 1/2, stride_y = 1/2 - * 9x9 convolution with stride_x = 1/2, stride_y = 1/2, data_layout=NHWC + * 9x9 convolution with stride_x = 1/2, stride_y = 1/2 * * @param[in] compile_context The compile context to be used. * @param[in] input The input tensor to convolve. 3 lower dimensions represent a single input [width, height, IFM], diff --git a/src/core/CL/cl_kernels/direct_convolution_quantized.cl b/src/core/CL/cl_kernels/direct_convolution_quantized.cl index ed1b7cfe2a..8237fe1700 100644 --- a/src/core/CL/cl_kernels/direct_convolution_quantized.cl +++ b/src/core/CL/cl_kernels/direct_convolution_quantized.cl @@ -33,7 +33,113 @@ #if defined(DATA_LAYOUT_NHWC) -#if KERNEL_SIZE == 5 +#if KERNEL_SIZE == 9 + +#if STRIDE_X == 1 +#define CONVOLUTION1x9(acc, src_ptr, weights_ptr) CONVOLUTION1x9_STRIDE1(acc, src_ptr, weights_ptr) +#elif STRIDE_X == 2 +#define CONVOLUTION1x9(acc, src_ptr, weights_ptr) CONVOLUTION1x9_STRIDE2(acc, src_ptr, weights_ptr) +#else /* STRIDE_X not equals 1 or 2 */ +#error "STRIDE_X larger than 2 is not supported" +#endif /* STRIDE_X */ + +#define CONVOLUTION1x9_STRIDE1(acc, src_ptr, weights_ptr) \ + ({ \ + int8 weights_values0 = 0; \ + int weights_value1 = 0; \ + weights_values0.s0 = convert_int(*(weights_ptr + 0 * weights_stride_y)); \ + weights_values0.s1 = convert_int(*(weights_ptr + 1 * weights_stride_y)); \ + weights_values0.s2 = convert_int(*(weights_ptr + 2 * weights_stride_y)); \ + weights_values0.s3 = convert_int(*(weights_ptr + 3 * weights_stride_y)); \ + weights_values0.s4 = convert_int(*(weights_ptr + 4 * weights_stride_y)); \ + weights_values0.s5 = convert_int(*(weights_ptr + 5 * weights_stride_y)); \ + weights_values0.s6 = convert_int(*(weights_ptr + 6 * weights_stride_y)); \ + weights_values0.s7 = convert_int(*(weights_ptr + 7 * weights_stride_y)); \ + weights_value1 = convert_int(*(weights_ptr + 8 * weights_stride_y)); \ + \ + int8 src0 = 0; \ + int8 src1 = 0; \ + src0.s0 = convert_int(*(src_ptr + 0 * weights_stride_y)); \ + src0.s1 = convert_int(*(src_ptr + 1 * weights_stride_y)); \ + src0.s2 = convert_int(*(src_ptr + 2 * weights_stride_y)); \ + src0.s3 = convert_int(*(src_ptr + 3 * weights_stride_y)); \ + src0.s4 = convert_int(*(src_ptr + 4 * weights_stride_y)); \ + src0.s5 = convert_int(*(src_ptr + 5 * weights_stride_y)); \ + src0.s6 = convert_int(*(src_ptr + 6 * weights_stride_y)); \ + src0.s7 = convert_int(*(src_ptr + 7 * weights_stride_y)); \ + src1.s0 = convert_int(*(src_ptr + 8 * weights_stride_y)); \ + src1.s1 = convert_int(*(src_ptr + 9 * weights_stride_y)); \ + src1.s2 = convert_int(*(src_ptr + 10 * weights_stride_y)); \ + src1.s3 = convert_int(*(src_ptr + 11 * weights_stride_y)); \ + src1.s4 = convert_int(*(src_ptr + 12 * weights_stride_y)); \ + src1.s5 = convert_int(*(src_ptr + 13 * weights_stride_y)); \ + src1.s6 = convert_int(*(src_ptr + 14 * weights_stride_y)); \ + src1.s7 = convert_int(*(src_ptr + 15 * weights_stride_y)); \ + \ + acc += src0 * (int8)weights_values0.s0; \ + acc += (int8)(src0.s1234, src0.s567, src1.s0) * (int8)weights_values0.s1; \ + acc += (int8)(src0.s234, src0.s567, src1.s01) * (int8)weights_values0.s2; \ + acc += (int8)(src0.s345, src0.s67, src1.s012) * (int8)weights_values0.s3; \ + acc += (int8)(src0.s4567, src1.s0123) * (int8)weights_values0.s4; \ + acc += (int8)(src0.s567, src1.s0123, src1.s4) * (int8)weights_values0.s5; \ + acc += (int8)(src0.s67, src1.s012, src1.s345) * (int8)weights_values0.s6; \ + acc += (int8)(src0.s7, src1.s0123, src1.s456) * (int8)weights_values0.s7; \ + acc += src1 * (int8)weights_value1; \ + }) + +#define CONVOLUTION1x9_STRIDE2(acc, src_ptr, weights_ptr) \ + ({ \ + int8 weights_values0 = 0; \ + int weights_value1 = 0; \ + weights_values0.s0 = convert_int(*(weights_ptr + 0 * weights_stride_y)); \ + weights_values0.s1 = convert_int(*(weights_ptr + 1 * weights_stride_y)); \ + weights_values0.s2 = convert_int(*(weights_ptr + 2 * weights_stride_y)); \ + weights_values0.s3 = convert_int(*(weights_ptr + 3 * weights_stride_y)); \ + weights_values0.s4 = convert_int(*(weights_ptr + 4 * weights_stride_y)); \ + weights_values0.s5 = convert_int(*(weights_ptr + 5 * weights_stride_y)); \ + weights_values0.s6 = convert_int(*(weights_ptr + 6 * weights_stride_y)); \ + weights_values0.s7 = convert_int(*(weights_ptr + 7 * weights_stride_y)); \ + weights_value1 = convert_int(*(weights_ptr + 8 * weights_stride_y)); \ + \ + int16 src0 = 0; \ + int8 src1 = 0; \ + src0.s0 = convert_int(*(src_ptr + 0 * weights_stride_y)); \ + src0.s1 = convert_int(*(src_ptr + 1 * weights_stride_y)); \ + src0.s2 = convert_int(*(src_ptr + 2 * weights_stride_y)); \ + src0.s3 = convert_int(*(src_ptr + 3 * weights_stride_y)); \ + src0.s4 = convert_int(*(src_ptr + 4 * weights_stride_y)); \ + src0.s5 = convert_int(*(src_ptr + 5 * weights_stride_y)); \ + src0.s6 = convert_int(*(src_ptr + 6 * weights_stride_y)); \ + src0.s7 = convert_int(*(src_ptr + 7 * weights_stride_y)); \ + src0.s8 = convert_int(*(src_ptr + 8 * weights_stride_y)); \ + src0.s9 = convert_int(*(src_ptr + 9 * weights_stride_y)); \ + src0.sA = convert_int(*(src_ptr + 10 * weights_stride_y)); \ + src0.sB = convert_int(*(src_ptr + 11 * weights_stride_y)); \ + src0.sC = convert_int(*(src_ptr + 12 * weights_stride_y)); \ + src0.sD = convert_int(*(src_ptr + 13 * weights_stride_y)); \ + src0.sE = convert_int(*(src_ptr + 14 * weights_stride_y)); \ + src0.sF = convert_int(*(src_ptr + 15 * weights_stride_y)); \ + src1.s0 = convert_int(*(src_ptr + 16 * weights_stride_y)); \ + src1.s1 = convert_int(*(src_ptr + 17 * weights_stride_y)); \ + src1.s2 = convert_int(*(src_ptr + 18 * weights_stride_y)); \ + src1.s3 = convert_int(*(src_ptr + 19 * weights_stride_y)); \ + src1.s4 = convert_int(*(src_ptr + 20 * weights_stride_y)); \ + src1.s5 = convert_int(*(src_ptr + 21 * weights_stride_y)); \ + src1.s6 = convert_int(*(src_ptr + 22 * weights_stride_y)); \ + src1.s7 = convert_int(*(src_ptr + 23 * weights_stride_y)); \ + \ + acc += src0.s02468ACE * (int8)weights_values0.s0; \ + acc += (int8)(src0.s1357, src0.s9BDF) * (int8)weights_values0.s1; \ + acc += (int8)(src0.s2468, src0.sACE, src1.s0) * (int8)weights_values0.s2; \ + acc += (int8)(src0.s3579, src0.sBDF, src1.s1) * (int8)weights_values0.s3; \ + acc += (int8)(src0.s468A, src0.sCE, src1.s02) * (int8)weights_values0.s4; \ + acc += (int8)(src0.s579, src0.sBDF, src1.s13) * (int8)weights_values0.s5; \ + acc += (int8)(src0.s68A, src0.sCE, src1.s024) * (int8)weights_values0.s6; \ + acc += (int8)(src0.s79B, src0.sDF, src1.s135) * (int8)weights_values0.s7; \ + acc += (int8)(src0.s8AC, src0.sE, src1.s0246) * (int8)weights_value1; \ + }) + +#elif KERNEL_SIZE == 5 #if STRIDE_X == 1 #define CONVOLUTION1x5(acc, src_ptr, weights_ptr) CONVOLUTION1x5_STRIDE1(acc, src_ptr, weights_ptr) @@ -331,7 +437,37 @@ __kernel void direct_convolution_quantized( for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d) { -#if KERNEL_SIZE == 5 +#if KERNEL_SIZE == 9 + if(y_coord < 0) + { + const int start_z = -y_coord; + for(int i = start_z; i < 9; ++i) + { + CONVOLUTION1x9(values0, (src_addr + i * (int)src_stride_z), (weights_addr + i * (int)weights_stride_z)); + } + } + else if(y_coord > (SRC_HEIGHT - 9)) + { + // Avoid loading rows beyond the input height + const int end_z = SRC_HEIGHT - y_coord; + for(int i = 0; i < end_z; ++i) + { + CONVOLUTION1x9(values0, (src_addr + i * (int)src_stride_z), (weights_addr + i * (int)weights_stride_z)); + } + } + else + { + CONVOLUTION1x9(values0, src_addr, weights_addr); + CONVOLUTION1x9(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z)); + CONVOLUTION1x9(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z)); + CONVOLUTION1x9(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z)); + CONVOLUTION1x9(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z)); + CONVOLUTION1x9(values0, (src_addr + 5 * (int)src_stride_z), (weights_addr + 5 * (int)weights_stride_z)); + CONVOLUTION1x9(values0, (src_addr + 6 * (int)src_stride_z), (weights_addr + 6 * (int)weights_stride_z)); + CONVOLUTION1x9(values0, (src_addr + 7 * (int)src_stride_z), (weights_addr + 7 * (int)weights_stride_z)); + CONVOLUTION1x9(values0, (src_addr + 8 * (int)src_stride_z), (weights_addr + 8 * (int)weights_stride_z)); + } +#elif KERNEL_SIZE == 5 #if(PAD_TOP == 1) || (PAD_BOTTM == 1) if(y_coord < 0) // special case Z = -1 doesn't exists { diff --git a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp index 0bf4afd81c..4acbe2dff8 100644 --- a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp +++ b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp @@ -60,20 +60,9 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, "Weights feature map dimension should match the respective input's one"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->num_dimensions() > 4, "Weights can be at most 4 dimensional"); ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 1) && std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported for 1x1 convolution."); - ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 3 || weights->dimension(width_idx) == 5) && std::get<0>(conv_info.stride()) > 2, - "Strides larger than 2 not supported for 3x3 convolution."); - - const auto data_type = input->data_type(); - - if(weights->dimension(width_idx) == 9) - { - const auto supported_data_layout = is_data_type_quantized(data_type) ? DataLayout::NCHW : DataLayout::NHWC; - const auto error_message = std::string("Only " + string_from_data_layout(supported_data_layout) + " layout is supported for 9x9 convolution with " + string_from_data_type( - data_type) - + " type"); - - ARM_COMPUTE_RETURN_ERROR_ON_MSG((supported_data_layout != data_layout), error_message.c_str()); - } + ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 3 || weights->dimension(width_idx) == 5 || weights->dimension(width_idx) == 9) + && std::get<0>(conv_info.stride()) > 2, + "Strides larger than 2 not supported for 3x3, 5x5, 9x9 convolution."); if(biases != nullptr) { @@ -99,6 +88,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); } + const auto data_type = input->data_type(); if(is_data_type_quantized(data_type)) { const UniformQuantizationInfo iqinfo = input->quantization_info().uniform(); diff --git a/tests/validation/CL/DirectConvolutionLayer.cpp b/tests/validation/CL/DirectConvolutionLayer.cpp index 8bb2648b64..767da943f2 100644 --- a/tests/validation/CL/DirectConvolutionLayer.cpp +++ b/tests/validation/CL/DirectConvolutionLayer.cpp @@ -281,7 +281,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall9x9, CLDirectConvolutionLayerQuantizedFixture