From 15bc8485ef463508838a549b7e8518bf05883155 Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Tue, 8 Dec 2020 14:34:00 +0000 Subject: [Review Shape] CLDepthwiseConvolutionLayer mismatches - Fixed a bug that corrected the number of dimensions of a TensorShape for added trailing 1s - Avoided adding offset_first_element for the Depthwise 3x3 NCHW OpenCL kernels, since it wouldn't align with the window which is based on the output - Adjusted padding requirements along the x for Depthwise 3x3 NCHW. The kernel should always add 2 * dilation_(x/y) to the num_elems_read_x/y - Adjusted the kernel's border_size given to the border handler at function level - Added the dataset that previously made the tests fail Resolves: COMPMID-4041 Change-Id: Ifab7d38b263f12173fcc96a5f0bd3375756c3c53 Signed-off-by: Giorgio Arena Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4673 Comments-Addressed: Arm Jenkins Reviewed-by: SiCong Li Tested-by: Arm Jenkins --- arm_compute/core/Dimensions.h | 17 ++++--- arm_compute/core/TensorShape.h | 2 +- src/core/CL/cl_kernels/depthwise_convolution.cl | 58 +++++++++++----------- .../CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp | 10 ++-- tests/datasets/DepthwiseConvolutionLayerDataset.h | 9 ++-- .../DilatedDepthwiseConvolutionLayerDataset.h | 10 ++-- .../CL/DepthwiseConvolutionLayerNative.cpp | 2 +- 7 files changed, 59 insertions(+), 49 deletions(-) diff --git a/arm_compute/core/Dimensions.h b/arm_compute/core/Dimensions.h index 960238c267..0e6e1f6681 100644 --- a/arm_compute/core/Dimensions.h +++ b/arm_compute/core/Dimensions.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 Arm Limited. + * Copyright (c) 2017-2020 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -68,14 +68,19 @@ public: /** Accessor to set the value of one of the dimensions. * - * @param[in] dimension Dimension for which the value is set. - * @param[in] value Value to be set for the dimension. + * @param[in] dimension Dimension for which the value is set. + * @param[in] value Value to be set for the dimension. + * @param[in] increase_dim_unit (Optional) Set to true if unit dimension increase the number of dimensions (e.g. for Coordinates), false otherwise (e.g. for TensorShapes) */ - void set(size_t dimension, T value) + void set(size_t dimension, T value, bool increase_dim_unit = true) { ARM_COMPUTE_ERROR_ON(dimension >= num_max_dimensions); - _id[dimension] = value; - _num_dimensions = std::max(_num_dimensions, dimension + 1); + _id[dimension] = value; + // Don't increase the number of dimensions if the new dimension is 1 + if(increase_dim_unit || value != 1) + { + _num_dimensions = std::max(_num_dimensions, dimension + 1); + } } /** Alias to access the size of the first dimension */ T x() const diff --git a/arm_compute/core/TensorShape.h b/arm_compute/core/TensorShape.h index b455a07767..fe3921f766 100644 --- a/arm_compute/core/TensorShape.h +++ b/arm_compute/core/TensorShape.h @@ -90,7 +90,7 @@ public: // Set the specified dimension and increase the number of dimensions if // necessary - Dimensions::set(dimension, value); + Dimensions::set(dimension, value, false); // Correct number dimensions to ignore trailing dimensions of size 1 if(apply_dim_correction) diff --git a/src/core/CL/cl_kernels/depthwise_convolution.cl b/src/core/CL/cl_kernels/depthwise_convolution.cl index 81fa01ae99..8ce5617858 100644 --- a/src/core/CL/cl_kernels/depthwise_convolution.cl +++ b/src/core/CL/cl_kernels/depthwise_convolution.cl @@ -338,7 +338,6 @@ __kernel void depthwise_convolution_3x3( #endif //defined(HAS_BIAS) ) { - Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src); Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst); Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights); @@ -351,7 +350,8 @@ __kernel void depthwise_convolution_3x3( __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z; - __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; + __global uchar *src_addr = src_ptr + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + get_global_id(2) * src_step_z - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * + (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; // Load the weights float3 weights_values0 = vload3(0, (__global float *)(weights_addr + 0 * weights_stride_y)); @@ -501,7 +501,6 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f32( #endif //defined(HAS_BIAS) ) { - Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src); Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst); Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights); @@ -515,7 +514,8 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f32( const int batch = get_global_id(2) / DST_CHANNELS; // Load relevant input and weights data (Accounts depth multiplier when indexing input, OFM = IFM * DEPTH_MULTIPLIER) __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z; - __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; + __global uchar *src_addr = src_ptr + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + get_global_id(2) * src_step_z - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * + (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; #if(DILATION_X == 1 && DILATION_Y == 1) // Load the weights @@ -547,13 +547,13 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f32( #else /* DILATION_X==1 && DILATION_Y==1 */ //3x3 Convolution of elements starting in 0th row - pixels0 = convolution_3x3_dilation_stridex1_stridey1_bifrost_f32(src_addr, src.stride_x, src.stride_y, 0, weights_addr, weights_stride_y); + pixels0 = convolution_3x3_dilation_stridex1_stridey1_bifrost_f32(src_addr, src_stride_x, src_stride_y, 0, weights_addr, weights_stride_y); //3x3 Convolution of elements starting in 1st row - pixels1 = convolution_3x3_dilation_stridex1_stridey1_bifrost_f32(src_addr, src.stride_x, src.stride_y, 1, weights_addr, weights_stride_y); + pixels1 = convolution_3x3_dilation_stridex1_stridey1_bifrost_f32(src_addr, src_stride_x, src_stride_y, 1, weights_addr, weights_stride_y); //3x3 Convolution of elements starting in 2nd row - pixels2 = convolution_3x3_dilation_stridex1_stridey1_bifrost_f32(src_addr, src.stride_x, src.stride_y, 2, weights_addr, weights_stride_y); + pixels2 = convolution_3x3_dilation_stridex1_stridey1_bifrost_f32(src_addr, src_stride_x, src_stride_y, 2, weights_addr, weights_stride_y); //3x3 Convolution of elements starting in 3rd row - pixels3 = convolution_3x3_dilation_stridex1_stridey1_bifrost_f32(src_addr, src.stride_x, src.stride_y, 3, weights_addr, weights_stride_y); + pixels3 = convolution_3x3_dilation_stridex1_stridey1_bifrost_f32(src_addr, src_stride_x, src_stride_y, 3, weights_addr, weights_stride_y); #endif /* DILATION_X==1 && DILATION_Y==1 */ @@ -621,7 +621,6 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f32( #endif //defined(HAS_BIAS) ) { - Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src); Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst); Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights); @@ -633,7 +632,8 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f32( const int batch = get_global_id(2) / DST_CHANNELS; // Load relevant input and weights data (Accounts depth multiplier when indexing input, OFM = IFM * DEPTH_MULTIPLIER) __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z; - __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; + __global uchar *src_addr = src_ptr + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + get_global_id(2) * src_step_z - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * + (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; #if(DILATION_X == 1 && DILATION_Y == 1) @@ -664,9 +664,9 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f32( #else /* DILATION_X==1 && DILATION_Y==1 */ //3x3 Convolution of elements starting in 0th row - pixels0 = convolution_3x3_dilation_stridex2_stridey2_bifrost_f32(src_addr, src.stride_x, src.stride_y, 0, weights_addr, weights_stride_y); + pixels0 = convolution_3x3_dilation_stridex2_stridey2_bifrost_f32(src_addr, src_stride_x, src_stride_y, 0, weights_addr, weights_stride_y); //3x3 Convolution of elements starting in 2nd row - pixels1 = convolution_3x3_dilation_stridex2_stridey2_bifrost_f32(src_addr, src.stride_x, src.stride_y, 2, weights_addr, weights_stride_y); + pixels1 = convolution_3x3_dilation_stridex2_stridey2_bifrost_f32(src_addr, src_stride_x, src_stride_y, 2, weights_addr, weights_stride_y); #endif /* DILATION_X==1 && DILATION_Y==1 */ #ifdef HAS_BIAS @@ -997,16 +997,16 @@ inline half4 convolution1x3_stride_3_f16(__global const uchar *left_pixel, * @return a half4 containing 4 convoluted values. */ inline half4 convolution3x3_f16( - Image *src, + __global uchar *src, uint src_stride_y, const half mat0, const half mat1, const half mat2, const half mat3, const half mat4, const half mat5, const half mat6, const half mat7, const half mat8) { half4 pixels; - pixels = convolution1x3_f16(offset(src, 0, 0), mat0, mat1, mat2); - pixels += convolution1x3_f16(offset(src, 0, DILATION_Y), mat3, mat4, mat5); - pixels += convolution1x3_f16(offset(src, 0, DILATION_Y * 2), mat6, mat7, mat8); + pixels = convolution1x3_f16(src, mat0, mat1, mat2); + pixels += convolution1x3_f16(src + DILATION_Y * src_stride_y, mat3, mat4, mat5); + pixels += convolution1x3_f16(src + DILATION_Y * 2 * src_stride_y, mat6, mat7, mat8); return pixels; } @@ -1059,7 +1059,6 @@ __kernel void depthwise_convolution_3x3_f16( #endif //defined(HAS_BIAS) ) { - Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src); Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst); Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights); #if defined(HAS_BIAS) @@ -1070,7 +1069,8 @@ __kernel void depthwise_convolution_3x3_f16( const int channel = get_global_id(2) % DST_CHANNELS; const int batch = get_global_id(2) / DST_CHANNELS; // Load relevant input and weights data (Accounts depth multiplier when indexing input, OFM = IFM * DEPTH_MULTIPLIER) - src.ptr -= batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z + (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; + __global uchar *src_addr = src_ptr + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + get_global_id(2) * src_step_z - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * + (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z; uchar3 offset = (uchar3)(0, 1, 2) * (uchar3)weights_stride_y; @@ -1078,7 +1078,7 @@ __kernel void depthwise_convolution_3x3_f16( half3 weights_values1 = vload3(0, (__global half *)(weights_addr + offset.s1)); half3 weights_values2 = vload3(0, (__global half *)(weights_addr + offset.s2)); - half4 pixels = convolution3x3_f16(&src, weights_values0.s0, weights_values0.s1, weights_values0.s2, + half4 pixels = convolution3x3_f16(src_addr, src_stride_y, weights_values0.s0, weights_values0.s1, weights_values0.s2, weights_values1.s0, weights_values1.s1, weights_values1.s2, weights_values2.s0, weights_values2.s1, weights_values2.s2); #if defined(HAS_BIAS) @@ -1137,7 +1137,6 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f16( #endif //defined(HAS_BIAS) ) { - Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src); Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst); Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights); @@ -1158,7 +1157,8 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f16( // Load relevant input and weights data (Accounts depth multiplier when indexing input, OFM = IFM * DEPTH_MULTIPLIER) __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z; - __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; + __global uchar *src_addr = src_ptr + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + get_global_id(2) * src_step_z - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * + (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; #if(DILATION_X == 1 && DILATION_Y == 1) // Load the weights @@ -1190,13 +1190,13 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f16( #else /* DILATION_X==1 && DILATION_Y==1 */ //3x3 Convolution of elements starting in 0th row - pixels0 = convolution_3x3_dilation_stridex1_stridey1_bifrost_f16(src_addr, src.stride_x, src.stride_y, 0, weights_addr, weights_stride_y); + pixels0 = convolution_3x3_dilation_stridex1_stridey1_bifrost_f16(src_addr, src_stride_x, src_stride_y, 0, weights_addr, weights_stride_y); //3x3 Convolution of elements starting in 1st row - pixels1 = convolution_3x3_dilation_stridex1_stridey1_bifrost_f16(src_addr, src.stride_x, src.stride_y, 1, weights_addr, weights_stride_y); + pixels1 = convolution_3x3_dilation_stridex1_stridey1_bifrost_f16(src_addr, src_stride_x, src_stride_y, 1, weights_addr, weights_stride_y); //3x3 Convolution of elements starting in 2nd row - pixels2 = convolution_3x3_dilation_stridex1_stridey1_bifrost_f16(src_addr, src.stride_x, src.stride_y, 2, weights_addr, weights_stride_y); + pixels2 = convolution_3x3_dilation_stridex1_stridey1_bifrost_f16(src_addr, src_stride_x, src_stride_y, 2, weights_addr, weights_stride_y); //3x3 Convolution of elements starting in 3rd row - pixels3 = convolution_3x3_dilation_stridex1_stridey1_bifrost_f16(src_addr, src.stride_x, src.stride_y, 3, weights_addr, weights_stride_y); + pixels3 = convolution_3x3_dilation_stridex1_stridey1_bifrost_f16(src_addr, src_stride_x, src_stride_y, 3, weights_addr, weights_stride_y); #endif /* DILATION_X==1 && DILATION_Y==1 */ @@ -1260,7 +1260,6 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f16( #endif //defined(HAS_BIAS) ) { - Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src); Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst); Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights); @@ -1279,7 +1278,8 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f16( // Load relevant input and weights data ( Accounts depth multiplier when indexing input, OFM = IFM * DEPTH_MULTIPLIER) __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z; - __global uchar *src_addr = src.ptr - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; + __global uchar *src_addr = src_ptr + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + get_global_id(2) * src_step_z - batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * + (DEPTH_MULTIPLIER - 1) * src_step_z - (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z; #if(DILATION_X == 1 && DILATION_Y == 1) @@ -1309,9 +1309,9 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f16( #else /* DILATION_X==1 && DILATION_Y==1 */ //3x3 Convolution of elements starting in 0th row - pixels0 = convolution_3x3_dilation_stridex2_stridey2_bifrost_f16(src_addr, src.stride_x, src.stride_y, 0, weights_addr, weights_stride_y); + pixels0 = convolution_3x3_dilation_stridex2_stridey2_bifrost_f16(src_addr, src_stride_x, src_stride_y, 0, weights_addr, weights_stride_y); //3x3 Convolution of elements starting in 2nd row - pixels1 = convolution_3x3_dilation_stridex2_stridey2_bifrost_f16(src_addr, src.stride_x, src.stride_y, 2, weights_addr, weights_stride_y); + pixels1 = convolution_3x3_dilation_stridex2_stridey2_bifrost_f16(src_addr, src_stride_x, src_stride_y, 2, weights_addr, weights_stride_y); #endif /* DILATION_X==1 && DILATION_Y==1 */ #ifdef HAS_BIAS diff --git a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp index 25d0d2799b..ba7a782bf1 100644 --- a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp +++ b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp @@ -211,8 +211,11 @@ std::pair validate_and_configure_window(ITensorInfo *input, ITen num_elems_read_per_iteration_x = 3 + (num_elems_written_per_iteration_x - 1) * conv_stride_x + (conv_stride_x > 1 ? 1 : 0); num_elems_read_per_iteration_y = num_elems_written_per_iteration_y + 2; } - num_elems_read_per_iteration_x += (num_elems_read_per_iteration_x - 1) * (dilation.x() - 1); - num_elems_read_per_iteration_y += (num_elems_read_per_iteration_y - 1) * (dilation.y() - 1); + // The OpenCL routine convolution1x3 does loadn(addr), loadn(addr + dilation_x) and loadn(addr + 2 * dilation_x) on the input. + // Each of the three convolution1x3 gets called by passing addr, (addr + dilation_y) and (addr + 2 * dilation_y) + // Hence we must add 2 * dilation.x/y() to the number of elements read in those axes per thread + num_elems_read_per_iteration_x += 2 * dilation.x(); + num_elems_read_per_iteration_y += 2 * dilation.y(); // Create window and update padding Window win = calculate_max_window(*output, Steps(num_elems_written_per_iteration_x, num_elems_written_per_iteration_y)); @@ -267,7 +270,6 @@ void CLDepthwiseConvolutionLayer3x3NCHWKernel::configure(const CLCompileContext _conv_stride_y = conv_info.stride().second; _conv_pad_left = conv_info.pad_left(); _conv_pad_top = conv_info.pad_top(); - _border_size = BorderSize(_conv_pad_top, conv_info.pad_right(), conv_info.pad_bottom(), _conv_pad_left); _output_multipliers = output_multipliers; _output_shifts = output_shifts; _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type()); @@ -280,6 +282,8 @@ void CLDepthwiseConvolutionLayer3x3NCHWKernel::configure(const CLCompileContext ARM_COMPUTE_ERROR_THROW_ON(win_config.first); ICLKernel::configure_internal(win_config.second); + _border_size = BorderSize(input->info()->padding()); + // Set build options CLBuildOptions build_opts; build_opts.add_option("-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(act_info.activation()))); diff --git a/tests/datasets/DepthwiseConvolutionLayerDataset.h b/tests/datasets/DepthwiseConvolutionLayerDataset.h index ed596d6d45..86804fb4c6 100644 --- a/tests/datasets/DepthwiseConvolutionLayerDataset.h +++ b/tests/datasets/DepthwiseConvolutionLayerDataset.h @@ -161,9 +161,9 @@ class SmallDepthwiseConvolutionLayerDataset3x3 final : public DepthwiseConvoluti public: SmallDepthwiseConvolutionLayerDataset3x3() { - add_config(TensorShape(3U, 3U, 2U), Size2D(3U, 3U), PadStrideInfo(1, 1, 0, 0)); + add_config(TensorShape(1U, 1U, 2U), Size2D(3U, 3U), PadStrideInfo(1, 1, 2, 2)); add_config(TensorShape(7U, 8U, 3U, 2U), Size2D(3U, 3U), PadStrideInfo(1, 1, 0, 0)); - add_config(TensorShape(21U, 31U, 9U, 4U), Size2D(3U, 3U), PadStrideInfo(1, 1, 1, 0)); + add_config(TensorShape(32U, 31U, 9U, 4U), Size2D(3U, 3U), PadStrideInfo(1, 1, 0, 0)); // Asymmetric padding add_config(TensorShape(33U, 27U, 11U), Size2D(3U, 3U), PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::FLOOR)); } @@ -186,7 +186,6 @@ class LargeDepthwiseConvolutionLayerDataset3x3 final : public DepthwiseConvoluti public: LargeDepthwiseConvolutionLayerDataset3x3() { - add_config(TensorShape(33U, 27U, 11U, 3U), Size2D(3U, 3U), PadStrideInfo(1, 1, 0, 1)); add_config(TensorShape(33U, 27U, 11U, 3U), Size2D(3U, 3U), PadStrideInfo(1, 1, 1, 1)); add_config(TensorShape(21U, 31U, 9U, 4U), Size2D(3U, 3U), PadStrideInfo(1, 2, 1, 0)); add_config(TensorShape(33U, 27U, 11U, 3U), Size2D(3U, 3U), PadStrideInfo(1, 2, 0, 1)); @@ -202,6 +201,8 @@ public: add_config(TensorShape(233U, 277U, 55U), Size2D(3U, 3U), PadStrideInfo(1, 2, 0, 0)); add_config(TensorShape(333U, 277U, 77U, 5U), Size2D(3U, 3U), PadStrideInfo(2, 3, 0, 1)); add_config(TensorShape(177U, 311U, 22U), Size2D(3U, 3U), PadStrideInfo(2, 1, 1, 1)); + // Width and height are a multipile of the processing tile size + add_config(TensorShape(32U, 21U, 11U, 3U), Size2D(3U, 3U), PadStrideInfo(1, 1, 0, 1)); } }; @@ -269,4 +270,4 @@ public: } // namespace datasets } // namespace test } // namespace arm_compute -#endif /* ARM_COMPUTE_TEST_DEPTHWISE_CONVOLUTION_DATASET */ +#endif /* ARM_COMPUTE_TEST_DEPTHWISE_CONVOLUTION_DATASET */ \ No newline at end of file diff --git a/tests/datasets/DilatedDepthwiseConvolutionLayerDataset.h b/tests/datasets/DilatedDepthwiseConvolutionLayerDataset.h index 38762f35d0..9e2a3cf548 100644 --- a/tests/datasets/DilatedDepthwiseConvolutionLayerDataset.h +++ b/tests/datasets/DilatedDepthwiseConvolutionLayerDataset.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Arm Limited. + * Copyright (c) 2019-2020 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -68,9 +68,9 @@ class SmallDepthwiseDilatedConvolutionLayerDataset3x3 final : public DepthwiseCo public: SmallDepthwiseDilatedConvolutionLayerDataset3x3() { - add_config(TensorShape(7U, 7U, 1U), Size2D(3U, 3U), PadStrideInfo(1, 1, 1, 0), Size2D(2U, 2U)); + add_config(TensorShape(1U, 1U, 1U), Size2D(3U, 3U), PadStrideInfo(1, 1, 2, 2), Size2D(2U, 2U)); add_config(TensorShape(7U, 7U, 1U), Size2D(3U, 3U), PadStrideInfo(1, 1, 2, 0), Size2D(2U, 2U)); - add_config(TensorShape(7U, 7U, 1U), Size2D(3U, 3U), PadStrideInfo(1, 1, 3, 0), Size2D(2U, 2U)); + add_config(TensorShape(16U, 7U, 1U), Size2D(3U, 3U), PadStrideInfo(1, 1, 0, 1), Size2D(2U, 2U)); // Different strides and dilations add_config(TensorShape(7U, 7U, 1U), Size2D(3U, 3U), PadStrideInfo(1, 1, 0, 0), Size2D(2U, 2U)); @@ -119,7 +119,7 @@ class LargeDepthwiseDilatedConvolutionLayerDataset3x3 final : public DepthwiseCo public: LargeDepthwiseDilatedConvolutionLayerDataset3x3() { - add_config(TensorShape(33U, 27U, 11U, 3U), Size2D(3U, 3U), PadStrideInfo(1, 1, 0, 1), Size2D(2U, 1U)); + add_config(TensorShape(32U, 27U, 11U, 3U), Size2D(3U, 3U), PadStrideInfo(1, 1, 0, 1), Size2D(2U, 1U)); add_config(TensorShape(33U, 27U, 11U, 3U), Size2D(3U, 3U), PadStrideInfo(1, 1, 1, 1), Size2D(2U, 2U)); add_config(TensorShape(21U, 31U, 9U, 4U), Size2D(3U, 3U), PadStrideInfo(1, 2, 1, 0), Size2D(2U, 2U)); add_config(TensorShape(33U, 27U, 11U, 3U), Size2D(3U, 3U), PadStrideInfo(1, 2, 0, 1), Size2D(2U, 1U)); @@ -140,4 +140,4 @@ public: } // namespace datasets } // namespace test } // namespace arm_compute -#endif /* ARM_COMPUTE_TEST_DILATED_CONVOLUTION_LAYER_DATASET */ +#endif /* ARM_COMPUTE_TEST_DILATED_CONVOLUTION_LAYER_DATASET */ \ No newline at end of file diff --git a/tests/validation/CL/DepthwiseConvolutionLayerNative.cpp b/tests/validation/CL/DepthwiseConvolutionLayerNative.cpp index b1cd379574..ac4ed0b3ca 100644 --- a/tests/validation/CL/DepthwiseConvolutionLayerNative.cpp +++ b/tests/validation/CL/DepthwiseConvolutionLayerNative.cpp @@ -63,7 +63,7 @@ RelativeTolerance rel_tolerance_f16(half_float::half(0.01f)); constexpr float abs_tolerance_f16(0.03f); /** Width values to test - Precommit */ -const auto width_values_precommit = framework::dataset::make("width", { 37U } ); +const auto width_values_precommit = framework::dataset::make("width", { 1U, 17U, 32U } ); /** Width values to test - Nightly */ const auto width_values_nightly = framework::dataset::make("width", { 53U, 47U } ); -- cgit v1.2.1