From 7657224de2b697a8a92cccf26d98e53ccd7c1a03 Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Wed, 4 Apr 2018 17:44:26 +0100 Subject: COMPMID-926 Add depth multiplier support to NEON/CL/GLES depthwise convolution Change-Id: I03f32c62350e5ea43e77bb15fc5a832d83719e3b Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/126657 Tested-by: Jenkins Reviewed-by: Michele DiGiorgio Reviewed-by: Georgios Pinitas --- src/core/CL/cl_kernels/depthwise_convolution.cl | 30 ++++++++++------ .../cl_kernels/depthwise_convolution_quantized.cl | 2 ++ .../CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp | 5 ++- .../CLDepthwiseConvolutionLayer3x3NHWCKernel.cpp | 13 ++++--- src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp | 9 ++--- .../cs_shaders/depthwise_convolution3x3.cs | 6 +++- .../GCDepthwiseConvolutionLayer3x3Kernel.cpp | 32 ++++------------- .../NEDepthwiseConvolutionLayer3x3Kernel.cpp | 41 ++++++++++++---------- src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp | 19 +++++----- .../CL/functions/CLDepthwiseConvolutionLayer.cpp | 23 +++++++----- .../functions/GCDepthwiseConvolutionLayer.cpp | 4 +-- .../NEON/functions/NEDepthwiseConvolutionLayer.cpp | 27 ++++++++------ 12 files changed, 115 insertions(+), 96 deletions(-) (limited to 'src') diff --git a/src/core/CL/cl_kernels/depthwise_convolution.cl b/src/core/CL/cl_kernels/depthwise_convolution.cl index 07e67f4f2c..21c28539ef 100644 --- a/src/core/CL/cl_kernels/depthwise_convolution.cl +++ b/src/core/CL/cl_kernels/depthwise_convolution.cl @@ -24,6 +24,7 @@ #include "helpers.h" +#if defined(DEPTH_MULTIPLIER) #if defined(CONV_STRIDE_X) #if CONV_STRIDE_X == 1 @@ -192,6 +193,8 @@ __kernel void depthwise_convolution_3x3( Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases); #endif //defined(HAS_BIAS) + src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z; + uchar3 offset = (uchar3)(0, 1, 2) * (uchar3)weights_stride_y; float3 weights_values0 = vload3(0, (__global float *)(weights.ptr + offset.s0)); float3 weights_values1 = vload3(0, (__global float *)(weights.ptr + offset.s1)); @@ -312,7 +315,7 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f32( float2 pixels3 = 0.0f; __global uchar *weights_addr = (__global uchar *)weights.ptr; - __global uchar *src_addr = (__global uchar *)offset(&src, 0, 0); + __global uchar *src_addr = src.ptr - (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z; // Load the weights float3 weights_row0 = vload3(0, (__global float *)(weights_addr + 0 * weights_stride_y)); @@ -407,7 +410,7 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f32( float2 pixels1 = 0.0f; __global uchar *weights_addr = (__global uchar *)weights.ptr; - __global uchar *src_addr = (__global uchar *)offset(&src, 0, 0); + __global uchar *src_addr = src.ptr - (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z; // Load the weights float3 weights_row0 = vload3(0, (__global float *)(weights_addr + 0 * weights_stride_y)); @@ -446,6 +449,8 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f32( vstore2(pixels1, 0, (__global float *)(dst.ptr + 1 * dst_stride_y)); } +#endif // defined(DEPTH_MULTIPLIER) + #if defined(SRC_WIDTH) && defined(DATA_TYPE) /** This kernel reshapes each of the tensor's low three dimensions to single rows. * @@ -501,11 +506,11 @@ __kernel void depthwise_weights_reshape( } #endif //defined(SRC_WIDTH) && defined(DATA_TYPE) -#if defined(STRIDE_X) && defined(STRIDE_Y) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(PAD_RIGHT) && defined(PAD_BOTTOM) && defined(KERNEL_WIDTH) && defined(KERNEL_HEIGHT) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(DATA_TYPE) && defined(PAD_VALUE) +#if defined(STRIDE_X) && defined(STRIDE_Y) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(PAD_RIGHT) && defined(PAD_BOTTOM) && defined(KERNEL_WIDTH) && defined(KERNEL_HEIGHT) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(DATA_TYPE) && defined(PAD_VALUE) && defined(DEPTH_MULTIPLIER) /** This kernel performs a reshaping of the input tensor to a tensor used to perform depthwise convolution using vector to matrix multiplication. * * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float - * @note The convolution information must be passed at compile time using -DSTRIDE_X, -DSTRIDE_Y, -DPAD_LEFT, -DPAD_TOP, -DPAD_RIGHT, -DPAD_BOTTOM, -DKERNEL_WIDHT, -DKERNEL_HEIGHT, -DSRC_WIDTH, -DSRC_HEIGHT + * @note The convolution information must be passed at compile time using -DSTRIDE_X, -DSTRIDE_Y, -DPAD_LEFT, -DPAD_TOP, -DPAD_RIGHT, -DPAD_BOTTOM, -DKERNEL_WIDHT, -DKERNEL_HEIGHT, -DSRC_WIDTH, -DSRC_HEIGHT, -DDEPTH_MULTIPLIER * * @param[in] src_ptr Pointer to the source tensor. Supported data types: QS8/QS16/F16/F32 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes) @@ -534,7 +539,7 @@ __kernel void depthwise_im2col(TENSOR3D_DECLARATION(src), TENSOR3D_DECLARATION(d const int src_x = -PAD_LEFT + src_pixel_linear % max_initial_x; const int src_y = -PAD_TOP + src_pixel_linear / max_initial_x * STRIDE_Y; - const int src_z = get_global_id(2); + const int src_z = get_global_id(2) / DEPTH_MULTIPLIER; __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + src_z * src_stride_z; __global DATA_TYPE *output_ptr = ((__global DATA_TYPE *)(dst.ptr)); @@ -558,7 +563,7 @@ __kernel void depthwise_im2col(TENSOR3D_DECLARATION(src), TENSOR3D_DECLARATION(d #endif // defined(HAS_BIAS) } -#endif //defined(STRIDE_X) && defined(STRIDE_Y) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(PAD_RIGHT) && defined(PAD_BOTTOM) && defined(KERNEL_WIDTH) && defined(KERNEL_HEIGHT) && defined(SRC_WIDTH) && defined(DATA_TYPE) && defined(PAD_VALUE) +#endif //defined(STRIDE_X) && defined(STRIDE_Y) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(PAD_RIGHT) && defined(PAD_BOTTOM) && defined(KERNEL_WIDTH) && defined(KERNEL_HEIGHT) && defined(SRC_WIDTH) && defined(DATA_TYPE) && defined(PAD_VALUE) && defined(DEPTH_MULTIPLIER) #if defined(CONV_WIDTH) && defined(CONV_HEIGHT) && defined(DATA_TYPE) @@ -597,7 +602,7 @@ __kernel void depthwise_vector_to_tensor( #endif //defined(CONV_WIDTH) && defined(CONV_HEIGHT) && defined(DATA_TYPE) -#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) +#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(DEPTH_MULTIPLIER) #if defined(CONV_STRIDE_X) #if CONV_STRIDE_X == 1 #define convolution1x3_f16 convolution1x3_stride_1_f16 @@ -716,6 +721,8 @@ inline half4 convolution3x3_f16( return pixels; } +#if defined(DEPTH_MULTIPLIER) + /** This OpenCL kernel computes the depthwise convolution 3x3 * * @param[in] src_ptr Pointer to the source image. Supported data types: F16 @@ -764,6 +771,8 @@ __kernel void depthwise_convolution_3x3_f16( Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases); #endif //defined(HAS_BIAS) + src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z; + uchar3 offset = (uchar3)(0, 1, 2) * (uchar3)weights_stride_y; half3 weights_values0 = vload3(0, (__global half *)(weights.ptr + offset.s0)); half3 weights_values1 = vload3(0, (__global half *)(weights.ptr + offset.s1)); @@ -778,6 +787,7 @@ __kernel void depthwise_convolution_3x3_f16( vstore4(pixels, 0, (__global half *)dst.ptr); } +#endif // defined(DEPTH_MULTIPLIER) #endif // defined(CONV_STRIDE_X) /** This OpenCL kernel is optimized for Bifrost architectures and computes the 16bit floating point depthwise convolution 3x3 @@ -838,7 +848,7 @@ __kernel void depthwise_convolution_3x3_stridex1_stridey1_bifrost_f16( half4 pixels3 = 0.0f; __global uchar *weights_addr = (__global uchar *)weights.ptr; - __global uchar *src_addr = (__global uchar *)offset(&src, 0, 0); + __global uchar *src_addr = (__global uchar *)offset(&src, 0, 0) - (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z; // Load the weights half3 weights_row0 = vload3(0, (__global half *)(weights_addr + 0 * weights_stride_y)); @@ -935,7 +945,7 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f16( half4 pixels1 = 0.0f; __global uchar *weights_addr = (__global uchar *)weights.ptr; - __global uchar *src_addr = (__global uchar *)offset(&src, 0, 0); + __global uchar *src_addr = (__global uchar *)offset(&src, 0, 0) - (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z; // Load the weights half3 weights_row0 = vload3(0, (__global half *)(weights_addr + 0 * weights_stride_y)); @@ -969,4 +979,4 @@ __kernel void depthwise_convolution_3x3_stridex2_stridey2_bifrost_f16( vstore4(pixels0, 0, (__global half *)(dst.ptr + 0 * dst_stride_y)); vstore4(pixels1, 0, (__global half *)(dst.ptr + 1 * dst_stride_y)); } -#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) +#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(DEPTH_MULTIPLIER) diff --git a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl index a0c0a8b1fb..ccb3a1ffe2 100644 --- a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl +++ b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl @@ -126,6 +126,8 @@ __kernel void depthwise_convolution_3x3_quantized_nchw( int bias_value = *((__global int *)(vector_offset(&biases, get_global_id(2)))); #endif //defined(HAS_BIAS) + src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z; + uchar3 w0 = vload3(0, weights.ptr + 0 * weights_stride_y); uchar3 w1 = vload3(0, weights.ptr + 1 * weights_stride_y); uchar3 w2 = vload3(0, weights.ptr + 2 * weights_stride_y); diff --git a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp index de68ceda11..1997a901fe 100644 --- a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp +++ b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp @@ -50,6 +50,7 @@ BorderSize CLDepthwiseConvolutionLayer3x3NCHWKernel::border_size() const } void CLDepthwiseConvolutionLayer3x3NCHWKernel::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, + unsigned int depth_multiplier, ActivationLayerInfo act_info) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32); @@ -73,7 +74,7 @@ void CLDepthwiseConvolutionLayer3x3NCHWKernel::configure(const ICLTensor *input, } // Get convolved dimensions - const TensorShape output_shape = compute_depthwise_convolution_shape(*input->info(), *weights->info(), conv_info); + const TensorShape output_shape = compute_depthwise_convolution_shape(*input->info(), *weights->info(), conv_info, depth_multiplier); // Output auto inizialitation if not yet initialized auto_init_if_empty(*output->info(), @@ -84,6 +85,7 @@ void CLDepthwiseConvolutionLayer3x3NCHWKernel::configure(const ICLTensor *input, input->info()->quantization_info()); ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape); + ARM_COMPUTE_ERROR_ON(output->info()->dimension(2) != weights->info()->dimension(2)); _input = input; _output = output; @@ -98,6 +100,7 @@ void CLDepthwiseConvolutionLayer3x3NCHWKernel::configure(const ICLTensor *input, // Set build options ARM_COMPUTE_ERROR_ON(_conv_stride_x < 1 || _conv_stride_x > 3); CLBuildOptions build_opts; + build_opts.add_option("-DDEPTH_MULTIPLIER=" + support::cpp11::to_string(depth_multiplier)); build_opts.add_option("-DCONV_STRIDE_X=" + support::cpp11::to_string(_conv_stride_x)); build_opts.add_option_if(_biases != nullptr, "-DHAS_BIAS"); diff --git a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NHWCKernel.cpp b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NHWCKernel.cpp index d783b9e159..a02b84fba1 100644 --- a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NHWCKernel.cpp +++ b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NHWCKernel.cpp @@ -41,7 +41,7 @@ using namespace arm_compute::misc::shape_calculator; namespace { -Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, +Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier, const ActivationLayerInfo &act_info) { ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8); @@ -50,6 +50,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, && (act_info.activation() != ActivationLayerInfo::ActivationFunction::RELU), "For QASYMM8 only relu, lower bounded relu and lower-upper bounded relu are supported"); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); + ARM_COMPUTE_RETURN_ERROR_ON(depth_multiplier > 1); // COMPMID-1071 Add depth multiplier support for NHWC ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(1) != 3 || weights->dimension(2) != 3); if(biases != nullptr) @@ -61,7 +62,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, if(output->total_size() != 0) { - const TensorShape output_shape = compute_depthwise_convolution_shape(*input, *weights, conv_info); + const TensorShape output_shape = compute_depthwise_convolution_shape(*input, *weights, conv_info, depth_multiplier); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape); } @@ -105,12 +106,13 @@ BorderSize CLDepthwiseConvolutionLayer3x3NHWCKernel::border_size() const } void CLDepthwiseConvolutionLayer3x3NHWCKernel::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, + unsigned int depth_multiplier, ActivationLayerInfo act_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); // Get convolved dimensions - const TensorShape output_shape = compute_depthwise_convolution_shape(*input->info(), *weights->info(), conv_info); + const TensorShape output_shape = compute_depthwise_convolution_shape(*input->info(), *weights->info(), conv_info, depth_multiplier); // Output auto inizialitation if not yet initialized auto_init_if_empty(*output->info(), @@ -120,7 +122,7 @@ void CLDepthwiseConvolutionLayer3x3NHWCKernel::configure(const ICLTensor *input, input->info()->fixed_point_position(), input->info()->quantization_info()); - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(), conv_info, act_info)); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(), conv_info, depth_multiplier, act_info)); const unsigned int conv_stride_x = conv_info.stride().first; ARM_COMPUTE_ERROR_ON(conv_stride_x < 1 || conv_stride_x > 2); @@ -208,9 +210,10 @@ void CLDepthwiseConvolutionLayer3x3NHWCKernel::configure(const ICLTensor *input, } Status CLDepthwiseConvolutionLayer3x3NHWCKernel::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, + unsigned int depth_multiplier, ActivationLayerInfo act_info) { - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, weights, biases, output, conv_info, act_info)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, weights, biases, output, conv_info, depth_multiplier, act_info)); ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), weights->clone().get(), output->clone().get(), conv_info).first); return Status{}; diff --git a/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp b/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp index a0784dcad6..0aef52f791 100644 --- a/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp +++ b/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp @@ -42,13 +42,13 @@ CLDepthwiseIm2ColKernel::CLDepthwiseIm2ColKernel() { } -void CLDepthwiseIm2ColKernel::configure(const ICLTensor *input, ICLTensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias) +void CLDepthwiseIm2ColKernel::configure(const ICLTensor *input, ICLTensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int depth_multiplier) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); ARM_COMPUTE_ERROR_ON(is_data_type_quantized_asymmetric(input->info()->data_type()) && has_bias); - ARM_COMPUTE_ERROR_ON(input->info()->dimension(2) != output->info()->dimension(2)); + ARM_COMPUTE_ERROR_ON((input->info()->dimension(2) * depth_multiplier) != output->info()->dimension(2)); ARM_COMPUTE_ERROR_ON(output->info()->dimension(0) != (kernel_dims.width * kernel_dims.height + ((has_bias) ? 1 : 0))); _input = input; @@ -68,6 +68,7 @@ void CLDepthwiseIm2ColKernel::configure(const ICLTensor *input, ICLTensor *outpu build_opts.add_option("-DSRC_HEIGHT=" + support::cpp11::to_string(input->info()->dimension(1))); build_opts.add_option("-DKERNEL_WIDTH=" + support::cpp11::to_string(kernel_dims.width)); build_opts.add_option("-DKERNEL_HEIGHT=" + support::cpp11::to_string(kernel_dims.height)); + build_opts.add_option("-DDEPTH_MULTIPLIER=" + support::cpp11::to_string(depth_multiplier)); build_opts.add_option_if(has_bias, "-DHAS_BIAS"); build_opts.add_option_if_else(is_data_type_quantized_asymmetric(input->info()->data_type()), "-DPAD_VALUE=" + support::cpp11::to_string(input->info()->quantization_info().offset), @@ -85,8 +86,8 @@ void CLDepthwiseIm2ColKernel::configure(const ICLTensor *input, ICLTensor *outpu } // Configure kernel window - Window win = calculate_max_window(*input->info(), Steps()); - // The CLDepthwiseIm2ColKernel doesn't need padding so update_window_and_padding() can be skipped + Window win = calculate_max_window(*output->info(), Steps()); + // CLDepthwiseIm2ColKernel doesn't need padding so update_window_and_padding() can be skipped output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape())); ICLKernel::configure(win); diff --git a/src/core/GLES_COMPUTE/cs_shaders/depthwise_convolution3x3.cs b/src/core/GLES_COMPUTE/cs_shaders/depthwise_convolution3x3.cs index adfc126c95..134cc1060f 100644 --- a/src/core/GLES_COMPUTE/cs_shaders/depthwise_convolution3x3.cs +++ b/src/core/GLES_COMPUTE/cs_shaders/depthwise_convolution3x3.cs @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -108,6 +108,8 @@ void main() uint z_index = gl_GlobalInvocationID.z; TENSOR_ITERATOR_ADVANCE_IN_BYTES(weights_iter, z_index * weights_attrs.stride_z); + src_iter.current_offset_in_bytes -= int((z_index - z_index / uint(DEPTH_MULTIPLIER)) * src_attrs.step_z); + vec4 w[3]; w[0] = LOAD_UNPACK4_CURRENT_ITEM_HALF(weights_ptr, weights_iter); w[1] = LOAD_UNPACK4_HALF(weights_ptr, TENSOR3D_OFFSET(weights_iter, 0, 1, 0)); @@ -263,6 +265,8 @@ void main() uint z_index = gl_GlobalInvocationID.z; TENSOR_ITERATOR_ADVANCE_IN_BYTES(weights_iter, z_index * weights_attrs.stride_z); + src_iter.current_offset_in_bytes -= int((z_index - z_index / uint(DEPTH_MULTIPLIER)) * src_attrs.step_z); + vec4 w[3]; w[0] = LOAD_UNPACK4_CURRENT_ITEM_HALF(weights_ptr, weights_iter); w[1] = LOAD_UNPACK4_HALF(weights_ptr, TENSOR3D_OFFSET(weights_iter, 0, 1, 0)); diff --git a/src/core/GLES_COMPUTE/kernels/GCDepthwiseConvolutionLayer3x3Kernel.cpp b/src/core/GLES_COMPUTE/kernels/GCDepthwiseConvolutionLayer3x3Kernel.cpp index 9343268d9e..c2374096a2 100644 --- a/src/core/GLES_COMPUTE/kernels/GCDepthwiseConvolutionLayer3x3Kernel.cpp +++ b/src/core/GLES_COMPUTE/kernels/GCDepthwiseConvolutionLayer3x3Kernel.cpp @@ -33,31 +33,10 @@ #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Utils.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" using namespace arm_compute; - -namespace -{ -/** Calculates expected output shape dimension - * - * @param[in] Input shape - * - * @return Expected output shape - */ -TensorShape get_output_shape(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info) -{ - unsigned int output_width = 0; - unsigned int output_height = 0; - - std::tie(output_width, output_height) = scaled_dimensions(input_shape.x(), input_shape.y(), weights_shape.x(), weights_shape.y(), conv_info); - - TensorShape output_shape = input_shape; - output_shape.set(0, output_width); - output_shape.set(1, output_height); - - return output_shape; -} -} // namespace +using namespace arm_compute::misc::shape_calculator; GCDepthwiseConvolutionLayer3x3Kernel::GCDepthwiseConvolutionLayer3x3Kernel() : _border_size(0), _input(), _output(), _weights(), _biases(), _conv_stride_x(0), _conv_stride_y(0), _conv_pad_left(0), _conv_pad_top(0), _lws(gles::NDRange(1U, 1U, 1U)) @@ -69,7 +48,8 @@ BorderSize GCDepthwiseConvolutionLayer3x3Kernel::border_size() const return _border_size; } -void GCDepthwiseConvolutionLayer3x3Kernel::configure(const IGCTensor *input, const IGCTensor *weights, const IGCTensor *biases, IGCTensor *output, const PadStrideInfo &conv_info) +void GCDepthwiseConvolutionLayer3x3Kernel::configure(const IGCTensor *input, const IGCTensor *weights, const IGCTensor *biases, IGCTensor *output, const PadStrideInfo &conv_info, + unsigned int depth_multiplier) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); @@ -83,7 +63,7 @@ void GCDepthwiseConvolutionLayer3x3Kernel::configure(const IGCTensor *input, con } // Get convolved dimensions - TensorShape output_shape = get_output_shape(input->info()->tensor_shape(), weights->info()->tensor_shape(), conv_info); + const TensorShape output_shape = compute_depthwise_convolution_shape(*input->info(), *weights->info(), conv_info, depth_multiplier); // Output auto inizialitation if not yet initialized auto_init_if_empty(*output->info(), @@ -93,6 +73,7 @@ void GCDepthwiseConvolutionLayer3x3Kernel::configure(const IGCTensor *input, con input->info()->fixed_point_position()); ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape); + ARM_COMPUTE_ERROR_ON(output->info()->dimension(2) != weights->info()->dimension(2)); _input = input; _output = output; @@ -108,6 +89,7 @@ void GCDepthwiseConvolutionLayer3x3Kernel::configure(const IGCTensor *input, con ARM_COMPUTE_ERROR_ON(_conv_stride_x < 1 || _conv_stride_x > 3); std::set options; + options.emplace("#define DEPTH_MULTIPLIER " + support::cpp11::to_string(depth_multiplier)); options.emplace("#define LOCAL_SIZE_X " + support::cpp11::to_string(_lws[0])); options.emplace("#define LOCAL_SIZE_Y " + support::cpp11::to_string(_lws[1])); options.emplace("#define LOCAL_SIZE_Z " + support::cpp11::to_string(_lws[2])); diff --git a/src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp b/src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp index 49c67d19bb..8cdf175d8a 100644 --- a/src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp +++ b/src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp @@ -52,13 +52,14 @@ class convolver_3x3 { public: static void convolve(const Window &window, unsigned int num_elems_written_per_iteration, - const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info) + const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier) { const int input_offset = -input->info()->quantization_info().offset; const int weights_offset = -weights->info()->quantization_info().offset; const int input_stride_x = input->info()->strides_in_bytes().x(); const int input_stride_y = input->info()->strides_in_bytes().y(); + const int input_stride_z = input->info()->strides_in_bytes().z(); const int output_stride_y = output->info()->strides_in_bytes().y(); const int kernel_stride_y = weights->info()->strides_in_bytes().y(); const int kernel_stride_z = weights->info()->strides_in_bytes().z(); @@ -93,7 +94,7 @@ public: int ih = 0; int oh = 0; - const uint8_t *input_ptr = in.ptr() - conv_pad_x * input_stride_x - conv_pad_y * input_stride_y; + const uint8_t *input_ptr = in.ptr() - conv_pad_x * input_stride_x - conv_pad_y * input_stride_y - (id.z() - id.z() / depth_multiplier) * input_stride_z; const uint8_t *ptr_weights_base = weights_ptr + id.z() * kernel_stride_z; const auto ptr_weights_r0 = reinterpret_cast(ptr_weights_base); @@ -125,19 +126,19 @@ public: template inline void convolve_3x3(const Window &window, unsigned int num_elems_written_per_iteration, - const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info) + const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier) { const unsigned int conv_stride_x = std::get<0>(conv_info.stride()); switch(conv_stride_x) { case 1: - convolver_3x3::convolve(window, num_elems_written_per_iteration, input, weights, output, conv_info); + convolver_3x3::convolve(window, num_elems_written_per_iteration, input, weights, output, conv_info, depth_multiplier); break; case 2: - convolver_3x3::convolve(window, num_elems_written_per_iteration, input, weights, output, conv_info); + convolver_3x3::convolve(window, num_elems_written_per_iteration, input, weights, output, conv_info, depth_multiplier); break; case 3: - convolver_3x3::convolve(window, num_elems_written_per_iteration, input, weights, output, conv_info); + convolver_3x3::convolve(window, num_elems_written_per_iteration, input, weights, output, conv_info, depth_multiplier); break; default: ARM_COMPUTE_ERROR("Not implemented"); @@ -146,7 +147,7 @@ inline void convolve_3x3(const Window &window, unsigned int num_elems_written_pe } // namespace NEDepthwiseConvolutionLayer3x3Kernel::NEDepthwiseConvolutionLayer3x3Kernel() - : _border_size(0), _input(), _output(), _weights(), _conv_info(), _convolver(nullptr), _num_elems_written_per_iteration(0), _run_optimized(false) + : _border_size(0), _input(), _output(), _weights(), _conv_info(), _convolver(nullptr), _num_elems_written_per_iteration(0), _run_optimized(false), _depth_multiplier(1) { } @@ -155,20 +156,22 @@ BorderSize NEDepthwiseConvolutionLayer3x3Kernel::border_size() const return _border_size; } -void NEDepthwiseConvolutionLayer3x3Kernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info, DataLayout data_layout) +void NEDepthwiseConvolutionLayer3x3Kernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier, + DataLayout data_layout) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F32); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); - _input = input; - _output = output; - _weights = weights; - _conv_info = conv_info; - _convolver = nullptr; + _input = input; + _output = output; + _weights = weights; + _conv_info = conv_info; + _depth_multiplier = depth_multiplier; + _convolver = nullptr; _run_optimized = NEDepthwiseConvolutionLayer3x3Kernel::is_optimized_execution_possible(input->info()->tensor_shape(), conv_info, - input->info()->data_type(), + input->info()->data_type(), depth_multiplier, data_layout); (_run_optimized) ? configure_optimized() : configure_generic(); @@ -182,7 +185,7 @@ void NEDepthwiseConvolutionLayer3x3Kernel::run(const Window &window, const Threa (_run_optimized) ? run_optimized(window, info) : run_generic(window, info); } -bool NEDepthwiseConvolutionLayer3x3Kernel::is_optimized_execution_possible(TensorShape input_shape, PadStrideInfo conv_info, DataType dt, DataLayout data_layout) +bool NEDepthwiseConvolutionLayer3x3Kernel::is_optimized_execution_possible(TensorShape input_shape, PadStrideInfo conv_info, DataType dt, unsigned int depth_multiplier, DataLayout data_layout) { // Reshape input shape if in NHWC format TensorShape in_shape{ input_shape }; @@ -210,7 +213,7 @@ bool NEDepthwiseConvolutionLayer3x3Kernel::is_optimized_execution_possible(Tenso bool is_valid_padding = (pad_top == 0) && (pad_right == 0) && (pad_bottom == 0) && (pad_left == 0); bool supported_padding = is_same_padding || is_valid_padding; - return supported_datatype && supported_strides && supported_padding; + return supported_datatype && supported_strides && supported_padding && (depth_multiplier == 1); } void NEDepthwiseConvolutionLayer3x3Kernel::generate_convolver() @@ -227,7 +230,7 @@ void NEDepthwiseConvolutionLayer3x3Kernel::configure_generic() ARM_COMPUTE_ERROR_ON(_weights->info()->dimension(0) != 3 || _weights->info()->dimension(1) != 3); // Get convolved dimensions - const TensorShape output_shape = compute_depthwise_convolution_shape(*_input->info(), *_weights->info(), _conv_info); + const TensorShape output_shape = compute_depthwise_convolution_shape(*_input->info(), *_weights->info(), _conv_info, _depth_multiplier); const DataType output_dt = (_input->info()->data_type() == DataType::QASYMM8) ? DataType::S32 : _input->info()->data_type(); // Output auto inizialitation if not yet initialized @@ -317,10 +320,10 @@ void NEDepthwiseConvolutionLayer3x3Kernel::run_generic(const Window &window, con switch(_input->info()->data_type()) { case DataType::F32: - convolve_3x3(window, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info); + convolve_3x3(window, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info, _depth_multiplier); break; case DataType::QASYMM8: - convolve_3x3(window, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info); + convolve_3x3(window, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info, _depth_multiplier); break; default: ARM_COMPUTE_ERROR("Not implemented"); diff --git a/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp b/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp index b924d9f8bd..cfd8eacfdd 100644 --- a/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp +++ b/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp @@ -85,7 +85,7 @@ void NEDepthwiseIm2ColKernel::run_generic(const Window &window) const int src_y = -pad_top + src_pixel_linear / max_initial_x * stride_y; // Get pointers - const uint8_t *const input_ptr = in.ptr() + id.z() * input_stride_z; + const uint8_t *const input_ptr = in.ptr() + id.z() / _depth_multiplier * input_stride_z; auto output_ptr = reinterpret_cast(out.ptr()); const int height = src_y + _kernel_dims.height; const int width = src_x + _kernel_dims.width; @@ -114,24 +114,25 @@ void NEDepthwiseIm2ColKernel::run_generic(const Window &window) } NEDepthwiseIm2ColKernel::NEDepthwiseIm2ColKernel() - : _func(nullptr), _input(nullptr), _output(nullptr), _kernel_dims(), _conv_info(), _has_bias() + : _func(nullptr), _input(nullptr), _output(nullptr), _kernel_dims(), _conv_info(), _has_bias(), _depth_multiplier(1) { } -void NEDepthwiseIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias) +void NEDepthwiseIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int depth_multiplier) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); ARM_COMPUTE_ERROR_ON(is_data_type_quantized_asymmetric(input->info()->data_type()) && has_bias); - ARM_COMPUTE_ERROR_ON(input->info()->dimension(2) != output->info()->dimension(2)); + ARM_COMPUTE_ERROR_ON((input->info()->dimension(2) * depth_multiplier) != output->info()->dimension(2)); ARM_COMPUTE_ERROR_ON(output->info()->dimension(0) != (kernel_dims.width * kernel_dims.height + ((has_bias) ? 1 : 0))); - _input = input; - _output = output; - _kernel_dims = kernel_dims; - _conv_info = conv_info; - _has_bias = has_bias; + _input = input; + _output = output; + _kernel_dims = kernel_dims; + _conv_info = conv_info; + _has_bias = has_bias; + _depth_multiplier = depth_multiplier; // Configure kernel window Window win = calculate_max_window(*input->info(), Steps()); diff --git a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp index 0276b37e09..ea2f93b85d 100644 --- a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp @@ -39,7 +39,8 @@ CLDepthwiseConvolutionLayer3x3::CLDepthwiseConvolutionLayer3x3() { } -void CLDepthwiseConvolutionLayer3x3::configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, ActivationLayerInfo act_info) +void CLDepthwiseConvolutionLayer3x3::configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier, + ActivationLayerInfo act_info) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); @@ -54,7 +55,7 @@ void CLDepthwiseConvolutionLayer3x3::configure(ICLTensor *input, const ICLTensor } _kernel->set_target(CLScheduler::get().target()); - _kernel->configure(input, weights, biases, output, conv_info, act_info); + _kernel->configure(input, weights, biases, output, conv_info, depth_multiplier, act_info); // Configure border handler PixelValue &&zero_value(0.f); @@ -77,11 +78,11 @@ CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayer() { } -void CLDepthwiseConvolutionLayer::configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info) +void CLDepthwiseConvolutionLayer::configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F32); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); - ARM_COMPUTE_ERROR_ON(input->info()->dimension(2) != weights->info()->dimension(2)); + ARM_COMPUTE_ERROR_ON((input->info()->dimension(2) * depth_multiplier) != weights->info()->dimension(2)); const size_t weights_w = weights->info()->dimension(0); const size_t weights_h = weights->info()->dimension(1); @@ -95,11 +96,15 @@ void CLDepthwiseConvolutionLayer::configure(ICLTensor *input, const ICLTensor *w const GPUTarget gpu_target = CLScheduler::get().target(); // Calculate output shape - TensorShape dwc_output_shape = shape_calculator::compute_depthwise_convolution_shape(*input->info(), *weights->info(), conv_info); + TensorShape output_shape = shape_calculator::compute_depthwise_convolution_shape(*input->info(), *weights->info(), conv_info, depth_multiplier); + + // Output auto inizialitation if not yet initialized + auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape)); + ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape); // Output width and height - const unsigned int conv_w = dwc_output_shape.x(); - const unsigned int conv_h = dwc_output_shape.y(); + const unsigned int conv_w = output_shape.x(); + const unsigned int conv_h = output_shape.y(); // Set up intermediate tensors const size_t patch_size = weights_w * weights_h + ((append_bias) ? 1 : 0); @@ -112,7 +117,7 @@ void CLDepthwiseConvolutionLayer::configure(ICLTensor *input, const ICLTensor *w shape_im2col.set(2, weights_z); _input_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col)); _im2col_kernel.set_target(gpu_target); - _im2col_kernel.configure(input, &_input_reshaped, Size2D(weights_w, weights_h), conv_info, append_bias); + _im2col_kernel.configure(input, &_input_reshaped, Size2D(weights_w, weights_h), conv_info, append_bias, depth_multiplier); // Weights reshape configuration const TensorShape shape_weights_reshape(patch_size, weights_z); @@ -128,7 +133,7 @@ void CLDepthwiseConvolutionLayer::configure(ICLTensor *input, const ICLTensor *w _v2mm_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_data_type(v2mm_dt).set_tensor_shape(shape_v2mm_out)); _v2mm_kernel.set_target(gpu_target); _v2mm_kernel.configure(&_input_reshaped, &_weights_reshaped, &_v2mm_output); - _output_reshaped.allocator()->init(_v2mm_output.info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(dwc_output_shape)); + _output_reshaped.allocator()->init(_v2mm_output.info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(output_shape)); _vector_to_tensor_kernel.configure(&_v2mm_output, (_is_quantized) ? &_output_reshaped : output, conv_w, conv_h); // Output staged configuration diff --git a/src/runtime/GLES_COMPUTE/functions/GCDepthwiseConvolutionLayer.cpp b/src/runtime/GLES_COMPUTE/functions/GCDepthwiseConvolutionLayer.cpp index 9cba37110b..7121654a75 100644 --- a/src/runtime/GLES_COMPUTE/functions/GCDepthwiseConvolutionLayer.cpp +++ b/src/runtime/GLES_COMPUTE/functions/GCDepthwiseConvolutionLayer.cpp @@ -35,10 +35,10 @@ GCDepthwiseConvolutionLayer3x3::GCDepthwiseConvolutionLayer3x3() { } -void GCDepthwiseConvolutionLayer3x3::configure(IGCTensor *input, const IGCTensor *weights, const IGCTensor *biases, IGCTensor *output, const PadStrideInfo &conv_info) +void GCDepthwiseConvolutionLayer3x3::configure(IGCTensor *input, const IGCTensor *weights, const IGCTensor *biases, IGCTensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier) { auto k = arm_compute::support::cpp14::make_unique(); - k->configure(input, weights, biases, output, conv_info); + k->configure(input, weights, biases, output, conv_info, depth_multiplier); _kernel = std::move(k); // Configure border handler diff --git a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp index 8691fb9f76..0a977ad08d 100644 --- a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp @@ -41,7 +41,7 @@ NEDepthwiseConvolutionLayer3x3::NEDepthwiseConvolutionLayer3x3() { } -void NEDepthwiseConvolutionLayer3x3::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info) +void NEDepthwiseConvolutionLayer3x3::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F32); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); @@ -53,6 +53,7 @@ void NEDepthwiseConvolutionLayer3x3::configure(ITensor *input, const ITensor *we _is_optimized = NEDepthwiseConvolutionLayer3x3Kernel::is_optimized_execution_possible(input->info()->tensor_shape(), conv_info, input->info()->data_type(), + depth_multiplier, input->info()->data_layout()); _are_weights_reshaped = false; _is_nchw = input->info()->data_layout() == DataLayout::NCHW; @@ -70,7 +71,7 @@ void NEDepthwiseConvolutionLayer3x3::configure(ITensor *input, const ITensor *we _permute_weights.configure(weights, &_weights_hwio, PermutationVector(2U, 0U, 1U)); // Configure optimized depthwise - _dwc_kernel.configure(&_input_nhwc, &_weights_hwio, &_output_nhwc, conv_info, DataLayout::NHWC); + _dwc_kernel.configure(&_input_nhwc, &_weights_hwio, &_output_nhwc, conv_info, depth_multiplier, DataLayout::NHWC); // Configure the function to transform the convoluted output to ACL's native ordering format NCHW _permute_output.configure(&_output_nhwc, output, PermutationVector(1U, 2U, 0U)); @@ -82,7 +83,7 @@ void NEDepthwiseConvolutionLayer3x3::configure(ITensor *input, const ITensor *we } else { - _dwc_kernel.configure(input, weights, output, conv_info, DataLayout::NHWC); + _dwc_kernel.configure(input, weights, output, conv_info, depth_multiplier, DataLayout::NHWC); } } else @@ -96,7 +97,7 @@ void NEDepthwiseConvolutionLayer3x3::configure(ITensor *input, const ITensor *we } // Configure depthwise convolution kernel - _dwc_kernel.configure(input, weights, (_is_quantized) ? &_accumulator : output, conv_info); + _dwc_kernel.configure(input, weights, (_is_quantized) ? &_accumulator : output, conv_info, depth_multiplier); // Configure border handler _border_handler.configure(input, _dwc_kernel.border_size(), BorderMode::CONSTANT, zero_value); @@ -175,11 +176,11 @@ NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayer() { } -void NEDepthwiseConvolutionLayer::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info) +void NEDepthwiseConvolutionLayer::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F32); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); - ARM_COMPUTE_ERROR_ON(input->info()->dimension(2) != weights->info()->dimension(2)); + ARM_COMPUTE_ERROR_ON((input->info()->dimension(2) * depth_multiplier) != weights->info()->dimension(2)); const size_t weights_w = weights->info()->dimension(0); const size_t weights_h = weights->info()->dimension(1); @@ -193,11 +194,15 @@ void NEDepthwiseConvolutionLayer::configure(ITensor *input, const ITensor *weigh bool append_bias = (biases != nullptr) && !_is_quantized; // Calculate output shape - TensorShape dwc_output_shape = shape_calculator::compute_depthwise_convolution_shape(*input->info(), *weights->info(), conv_info); + TensorShape output_shape = shape_calculator::compute_depthwise_convolution_shape(*input->info(), *weights->info(), conv_info, depth_multiplier); + + // Output auto inizialitation if not yet initialized + auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape)); + ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape); // Output width and height - const unsigned int conv_w = dwc_output_shape.x(); - const unsigned int conv_h = dwc_output_shape.y(); + const unsigned int conv_w = output_shape.x(); + const unsigned int conv_h = output_shape.y(); // Set up intermediate tensors const size_t patch_size = weights_w * weights_h + (append_bias ? 1 : 0); @@ -209,7 +214,7 @@ void NEDepthwiseConvolutionLayer::configure(ITensor *input, const ITensor *weigh shape_im2col.set(1, conv_size); shape_im2col.set(2, weights_z); _input_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col)); - _im2col_kernel.configure(input, &_input_reshaped, Size2D(weights_w, weights_h), conv_info, append_bias); + _im2col_kernel.configure(input, &_input_reshaped, Size2D(weights_w, weights_h), conv_info, append_bias, depth_multiplier); // Weights reshape configuration const TensorShape shape_weights_reshape(patch_size, weights_z); @@ -224,7 +229,7 @@ void NEDepthwiseConvolutionLayer::configure(ITensor *input, const ITensor *weigh shape_v2mm_out.set(2, 1); _v2mm_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_data_type(v2mm_dt).set_tensor_shape(shape_v2mm_out)); _v2mm_kernel.configure(&_input_reshaped, &_weights_reshaped, &_v2mm_output); - _output_reshaped.allocator()->init(_v2mm_output.info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(dwc_output_shape)); + _output_reshaped.allocator()->init(_v2mm_output.info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(output_shape)); _vector_to_tensor_kernel.configure(&_v2mm_output, (_is_quantized) ? &_output_reshaped : output, conv_w, conv_h); // Output staged configuration -- cgit v1.2.1