aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Georgios Pinitas <georgios.pinitas@arm.com> 2020-07-21 22:45:13 +0100
committer Georgios Pinitas <georgios.pinitas@arm.com> 2020-07-22 11:37:50 +0000
commit aa95ddc2abb7cef0b2edd03f7c4c9d9c6b9d7cf4 (patch)
tree 7fb62d8550934c628438b1aae6de053a3f923609
parent f932d2c8409831cb9cb97a2eb65be93ad4709cd6 (diff)
download ComputeLibrary-aa95ddc2abb7cef0b2edd03f7c4c9d9c6b9d7cf4.tar.gz
COMPMID-3535: 9x9 Direct convolution support for CL and NHWC
* Supported strides 1 and 2

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I4b9f087c0c328234159b2d1eacc2e465b3bb3c54
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3603
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r-- arm_compute/core/CL/kernels/CLDirectConvolutionLayerKernel.h 4
-rw-r--r-- src/core/CL/cl_kernels/direct_convolution_quantized.cl 140
-rw-r--r-- src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp 18
-rw-r--r-- tests/validation/CL/DirectConvolutionLayer.cpp 2
4 files changed, 145 insertions, 19 deletions
diff --git a/arm_compute/core/CL/kernels/CLDirectConvolutionLayerKernel.h b/arm_compute/core/CL/kernels/CLDirectConvolutionLayerKernel.h
index b01ce6c0e8..5281a0c306 100644
--- a/arm_compute/core/CL/kernels/CLDirectConvolutionLayerKernel.h
+++ b/arm_compute/core/CL/kernels/CLDirectConvolutionLayerKernel.h
@@ -54,7 +54,7 @@ public:
* 1x1 convolution with stride_x = 1/2/3, stride_y = 1/2/3
* 3x3 convolution with stride_x = 1/2, stride_y = 1/2
* 5x5 convolution with stride_x = 1/2, stride_y = 1/2
- * 9x9 convolution with stride_x = 1/2, stride_y = 1/2, data_layout=NHWC
+ * 9x9 convolution with stride_x = 1/2, stride_y = 1/2
*
* @param[in] input The input tensor to convolve. 3 lower dimensions represent a single input [width, height, IFM],
* while every optional dimension from 4 and above represent a batch of inputs. Data types supported: QASYMM8_SIGNED/QASYMM8/F16/F32.
@@ -74,7 +74,7 @@ public:
* 1x1 convolution with stride_x = 1/2/3, stride_y = 1/2/3
* 3x3 convolution with stride_x = 1/2, stride_y = 1/2
* 5x5 convolution with stride_x = 1/2, stride_y = 1/2
- * 9x9 convolution with stride_x = 1/2, stride_y = 1/2, data_layout=NHWC
+ * 9x9 convolution with stride_x = 1/2, stride_y = 1/2
*
* @param[in] compile_context The compile context to be used.
* @param[in] input The input tensor to convolve. 3 lower dimensions represent a single input [width, height, IFM],
diff --git a/src/core/CL/cl_kernels/direct_convolution_quantized.cl b/src/core/CL/cl_kernels/direct_convolution_quantized.cl
index ed1b7cfe2a..8237fe1700 100644
--- a/src/core/CL/cl_kernels/direct_convolution_quantized.cl
+++ b/src/core/CL/cl_kernels/direct_convolution_quantized.cl
@@ -33,7 +33,113 @@
#if defined(DATA_LAYOUT_NHWC)
-#if KERNEL_SIZE == 5
+#if KERNEL_SIZE == 9
+
+#if STRIDE_X == 1 /* Bind CONVOLUTION1x9 to the row helper that matches the compile-time STRIDE_X */
+#define CONVOLUTION1x9(acc, src_ptr, weights_ptr) CONVOLUTION1x9_STRIDE1(acc, src_ptr, weights_ptr)
+#elif STRIDE_X == 2
+#define CONVOLUTION1x9(acc, src_ptr, weights_ptr) CONVOLUTION1x9_STRIDE2(acc, src_ptr, weights_ptr)
+#else /* STRIDE_X is neither 1 nor 2; this branch is also taken for 0 or an undefined STRIDE_X, which the message below does not mention */
+#error "STRIDE_X larger than 2 is not supported"
+#endif /* STRIDE_X */
+
+#define CONVOLUTION1x9_STRIDE1(acc, src_ptr, weights_ptr) /* Adds the products of one 9-tap row to the 8-lane accumulator acc, for STRIDE_X == 1 */ \
+ ({ \
+ int8 weights_values0 = 0; /* taps 0..7 */ \
+ int weights_value1 = 0; /* tap 8 */ \
+ weights_values0.s0 = convert_int(*(weights_ptr + 0 * weights_stride_y)); /* taps are read weights_stride_y apart; that name must be in scope where the macro is expanded */ \
+ weights_values0.s1 = convert_int(*(weights_ptr + 1 * weights_stride_y)); \
+ weights_values0.s2 = convert_int(*(weights_ptr + 2 * weights_stride_y)); \
+ weights_values0.s3 = convert_int(*(weights_ptr + 3 * weights_stride_y)); \
+ weights_values0.s4 = convert_int(*(weights_ptr + 4 * weights_stride_y)); \
+ weights_values0.s5 = convert_int(*(weights_ptr + 5 * weights_stride_y)); \
+ weights_values0.s6 = convert_int(*(weights_ptr + 6 * weights_stride_y)); \
+ weights_values0.s7 = convert_int(*(weights_ptr + 7 * weights_stride_y)); \
+ weights_value1 = convert_int(*(weights_ptr + 8 * weights_stride_y)); \
+ \
+ int8 src0 = 0; /* inputs 0..7 */ \
+ int8 src1 = 0; /* inputs 8..15 */ \
+ src0.s0 = convert_int(*(src_ptr + 0 * weights_stride_y)); /* NOTE(review): inputs are stepped by weights_stride_y, not by a source stride - presumably the two are equal for this layout; confirm */ \
+ src0.s1 = convert_int(*(src_ptr + 1 * weights_stride_y)); \
+ src0.s2 = convert_int(*(src_ptr + 2 * weights_stride_y)); \
+ src0.s3 = convert_int(*(src_ptr + 3 * weights_stride_y)); \
+ src0.s4 = convert_int(*(src_ptr + 4 * weights_stride_y)); \
+ src0.s5 = convert_int(*(src_ptr + 5 * weights_stride_y)); \
+ src0.s6 = convert_int(*(src_ptr + 6 * weights_stride_y)); \
+ src0.s7 = convert_int(*(src_ptr + 7 * weights_stride_y)); \
+ src1.s0 = convert_int(*(src_ptr + 8 * weights_stride_y)); \
+ src1.s1 = convert_int(*(src_ptr + 9 * weights_stride_y)); \
+ src1.s2 = convert_int(*(src_ptr + 10 * weights_stride_y)); \
+ src1.s3 = convert_int(*(src_ptr + 11 * weights_stride_y)); \
+ src1.s4 = convert_int(*(src_ptr + 12 * weights_stride_y)); \
+ src1.s5 = convert_int(*(src_ptr + 13 * weights_stride_y)); \
+ src1.s6 = convert_int(*(src_ptr + 14 * weights_stride_y)); \
+ src1.s7 = convert_int(*(src_ptr + 15 * weights_stride_y)); \
+ \
+ acc += src0 * (int8)weights_values0.s0; /* tap k is broadcast to 8 lanes and multiplied by inputs k..k+7, so lane j accumulates input[j + k] * weight[k] */ \
+ acc += (int8)(src0.s1234, src0.s567, src1.s0) * (int8)weights_values0.s1; \
+ acc += (int8)(src0.s234, src0.s567, src1.s01) * (int8)weights_values0.s2; \
+ acc += (int8)(src0.s345, src0.s67, src1.s012) * (int8)weights_values0.s3; \
+ acc += (int8)(src0.s4567, src1.s0123) * (int8)weights_values0.s4; \
+ acc += (int8)(src0.s567, src1.s0123, src1.s4) * (int8)weights_values0.s5; \
+ acc += (int8)(src0.s67, src1.s012, src1.s345) * (int8)weights_values0.s6; \
+ acc += (int8)(src0.s7, src1.s0123, src1.s456) * (int8)weights_values0.s7; \
+ acc += src1 * (int8)weights_value1; /* tap 8 uses inputs 8..15 */ \
+ })
+
+#define CONVOLUTION1x9_STRIDE2(acc, src_ptr, weights_ptr) /* Adds the products of one 9-tap row to the 8-lane accumulator acc, for STRIDE_X == 2 */ \
+ ({ \
+ int8 weights_values0 = 0; /* taps 0..7 */ \
+ int weights_value1 = 0; /* tap 8 */ \
+ weights_values0.s0 = convert_int(*(weights_ptr + 0 * weights_stride_y)); /* taps are read weights_stride_y apart; that name must be in scope where the macro is expanded */ \
+ weights_values0.s1 = convert_int(*(weights_ptr + 1 * weights_stride_y)); \
+ weights_values0.s2 = convert_int(*(weights_ptr + 2 * weights_stride_y)); \
+ weights_values0.s3 = convert_int(*(weights_ptr + 3 * weights_stride_y)); \
+ weights_values0.s4 = convert_int(*(weights_ptr + 4 * weights_stride_y)); \
+ weights_values0.s5 = convert_int(*(weights_ptr + 5 * weights_stride_y)); \
+ weights_values0.s6 = convert_int(*(weights_ptr + 6 * weights_stride_y)); \
+ weights_values0.s7 = convert_int(*(weights_ptr + 7 * weights_stride_y)); \
+ weights_value1 = convert_int(*(weights_ptr + 8 * weights_stride_y)); \
+ \
+ int16 src0 = 0; /* inputs 0..15 */ \
+ int8 src1 = 0; /* inputs 16..23 */ \
+ src0.s0 = convert_int(*(src_ptr + 0 * weights_stride_y)); /* NOTE(review): inputs are stepped by weights_stride_y, not by a source stride - presumably the two are equal for this layout; confirm */ \
+ src0.s1 = convert_int(*(src_ptr + 1 * weights_stride_y)); \
+ src0.s2 = convert_int(*(src_ptr + 2 * weights_stride_y)); \
+ src0.s3 = convert_int(*(src_ptr + 3 * weights_stride_y)); \
+ src0.s4 = convert_int(*(src_ptr + 4 * weights_stride_y)); \
+ src0.s5 = convert_int(*(src_ptr + 5 * weights_stride_y)); \
+ src0.s6 = convert_int(*(src_ptr + 6 * weights_stride_y)); \
+ src0.s7 = convert_int(*(src_ptr + 7 * weights_stride_y)); \
+ src0.s8 = convert_int(*(src_ptr + 8 * weights_stride_y)); \
+ src0.s9 = convert_int(*(src_ptr + 9 * weights_stride_y)); \
+ src0.sA = convert_int(*(src_ptr + 10 * weights_stride_y)); \
+ src0.sB = convert_int(*(src_ptr + 11 * weights_stride_y)); \
+ src0.sC = convert_int(*(src_ptr + 12 * weights_stride_y)); \
+ src0.sD = convert_int(*(src_ptr + 13 * weights_stride_y)); \
+ src0.sE = convert_int(*(src_ptr + 14 * weights_stride_y)); \
+ src0.sF = convert_int(*(src_ptr + 15 * weights_stride_y)); \
+ src1.s0 = convert_int(*(src_ptr + 16 * weights_stride_y)); \
+ src1.s1 = convert_int(*(src_ptr + 17 * weights_stride_y)); \
+ src1.s2 = convert_int(*(src_ptr + 18 * weights_stride_y)); \
+ src1.s3 = convert_int(*(src_ptr + 19 * weights_stride_y)); \
+ src1.s4 = convert_int(*(src_ptr + 20 * weights_stride_y)); \
+ src1.s5 = convert_int(*(src_ptr + 21 * weights_stride_y)); \
+ src1.s6 = convert_int(*(src_ptr + 22 * weights_stride_y)); \
+ src1.s7 = convert_int(*(src_ptr + 23 * weights_stride_y)); /* NOTE(review): src1.s7 (input 23) is loaded but not referenced by any accumulation below */ \
+ \
+ acc += src0.s02468ACE * (int8)weights_values0.s0; /* tap k is broadcast to 8 lanes and multiplied by inputs k, k+2, ..., k+14, so lane j accumulates input[2 * j + k] * weight[k] */ \
+ acc += (int8)(src0.s1357, src0.s9BDF) * (int8)weights_values0.s1; \
+ acc += (int8)(src0.s2468, src0.sACE, src1.s0) * (int8)weights_values0.s2; \
+ acc += (int8)(src0.s3579, src0.sBDF, src1.s1) * (int8)weights_values0.s3; \
+ acc += (int8)(src0.s468A, src0.sCE, src1.s02) * (int8)weights_values0.s4; \
+ acc += (int8)(src0.s579, src0.sBDF, src1.s13) * (int8)weights_values0.s5; \
+ acc += (int8)(src0.s68A, src0.sCE, src1.s024) * (int8)weights_values0.s6; \
+ acc += (int8)(src0.s79B, src0.sDF, src1.s135) * (int8)weights_values0.s7; \
+ acc += (int8)(src0.s8AC, src0.sE, src1.s0246) * (int8)weights_value1; /* tap 8 uses inputs 8, 10, ..., 22 */ \
+ })
+
+#elif KERNEL_SIZE == 5
#if STRIDE_X == 1
#define CONVOLUTION1x5(acc, src_ptr, weights_ptr) CONVOLUTION1x5_STRIDE1(acc, src_ptr, weights_ptr)
@@ -331,7 +437,37 @@ __kernel void direct_convolution_quantized(
for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d)
{
-#if KERNEL_SIZE == 5
+#if KERNEL_SIZE == 9
+ if(y_coord < 0)
+ {
+ const int start_z = -y_coord;
+ for(int i = start_z; i < 9; ++i)
+ {
+ CONVOLUTION1x9(values0, (src_addr + i * (int)src_stride_z), (weights_addr + i * (int)weights_stride_z));
+ }
+ }
+ else if(y_coord > (SRC_HEIGHT - 9))
+ {
+ // Avoid loading rows beyond the input height
+ const int end_z = SRC_HEIGHT - y_coord;
+ for(int i = 0; i < end_z; ++i)
+ {
+ CONVOLUTION1x9(values0, (src_addr + i * (int)src_stride_z), (weights_addr + i * (int)weights_stride_z));
+ }
+ }
+ else
+ {
+ CONVOLUTION1x9(values0, src_addr, weights_addr);
+ CONVOLUTION1x9(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z));
+ CONVOLUTION1x9(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z));
+ CONVOLUTION1x9(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z));
+ CONVOLUTION1x9(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z));
+ CONVOLUTION1x9(values0, (src_addr + 5 * (int)src_stride_z), (weights_addr + 5 * (int)weights_stride_z));
+ CONVOLUTION1x9(values0, (src_addr + 6 * (int)src_stride_z), (weights_addr + 6 * (int)weights_stride_z));
+ CONVOLUTION1x9(values0, (src_addr + 7 * (int)src_stride_z), (weights_addr + 7 * (int)weights_stride_z));
+ CONVOLUTION1x9(values0, (src_addr + 8 * (int)src_stride_z), (weights_addr + 8 * (int)weights_stride_z));
+ }
+#elif KERNEL_SIZE == 5
#if(PAD_TOP == 1) || (PAD_BOTTM == 1)
if(y_coord < 0) // special case Z = -1 doesn't exists
{
diff --git a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
index 0bf4afd81c..4acbe2dff8 100644
--- a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
+++ b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
@@ -60,20 +60,9 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights,
"Weights feature map dimension should match the respective input's one");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->num_dimensions() > 4, "Weights can be at most 4 dimensional");
ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 1) && std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported for 1x1 convolution.");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 3 || weights->dimension(width_idx) == 5) && std::get<0>(conv_info.stride()) > 2,
- "Strides larger than 2 not supported for 3x3 convolution.");
-
- const auto data_type = input->data_type();
-
- if(weights->dimension(width_idx) == 9)
- {
- const auto supported_data_layout = is_data_type_quantized(data_type) ? DataLayout::NCHW : DataLayout::NHWC;
- const auto error_message = std::string("Only " + string_from_data_layout(supported_data_layout) + " layout is supported for 9x9 convolution with " + string_from_data_type(
- data_type)
- + " type");
-
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((supported_data_layout != data_layout), error_message.c_str());
- }
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 3 || weights->dimension(width_idx) == 5 || weights->dimension(width_idx) == 9)
+ && std::get<0>(conv_info.stride()) > 2,
+ "Strides larger than 2 not supported for 3x3, 5x5, 9x9 convolution.");
if(biases != nullptr)
{
@@ -99,6 +88,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights,
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
}
+ const auto data_type = input->data_type();
if(is_data_type_quantized(data_type))
{
const UniformQuantizationInfo iqinfo = input->quantization_info().uniform();
diff --git a/tests/validation/CL/DirectConvolutionLayer.cpp b/tests/validation/CL/DirectConvolutionLayer.cpp
index 8bb2648b64..767da943f2 100644
--- a/tests/validation/CL/DirectConvolutionLayer.cpp
+++ b/tests/validation/CL/DirectConvolutionLayer.cpp
@@ -281,7 +281,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall9x9, CLDirectConvolutionLayerQuantizedFixture<uin
DataType::QASYMM8)),
framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255, 10), QuantizationInfo(1.1f, 10) })),
QuantizedActivationFunctionsDataset),
- framework::dataset::make("DataLayout", { DataLayout::NCHW })))
+ framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_qasymm8);