From 7657224de2b697a8a92cccf26d98e53ccd7c1a03 Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Wed, 4 Apr 2018 17:44:26 +0100 Subject: COMPMID-926 Add depth multiplier support to NEON/CL/GLES depthwise convolution Change-Id: I03f32c62350e5ea43e77bb15fc5a832d83719e3b Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/126657 Tested-by: Jenkins Reviewed-by: Michele DiGiorgio Reviewed-by: Georgios Pinitas --- src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) (limited to 'src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp') diff --git a/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp b/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp index b924d9f8bd..cfd8eacfdd 100644 --- a/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp +++ b/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp @@ -85,7 +85,7 @@ void NEDepthwiseIm2ColKernel::run_generic(const Window &window) const int src_y = -pad_top + src_pixel_linear / max_initial_x * stride_y; // Get pointers - const uint8_t *const input_ptr = in.ptr() + id.z() * input_stride_z; + const uint8_t *const input_ptr = in.ptr() + id.z() / _depth_multiplier * input_stride_z; auto output_ptr = reinterpret_cast(out.ptr()); const int height = src_y + _kernel_dims.height; const int width = src_x + _kernel_dims.width; @@ -114,24 +114,25 @@ void NEDepthwiseIm2ColKernel::run_generic(const Window &window) } NEDepthwiseIm2ColKernel::NEDepthwiseIm2ColKernel() - : _func(nullptr), _input(nullptr), _output(nullptr), _kernel_dims(), _conv_info(), _has_bias() + : _func(nullptr), _input(nullptr), _output(nullptr), _kernel_dims(), _conv_info(), _has_bias(), _depth_multiplier(1) { } -void NEDepthwiseIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias) +void NEDepthwiseIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int depth_multiplier) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); ARM_COMPUTE_ERROR_ON(is_data_type_quantized_asymmetric(input->info()->data_type()) && has_bias); - ARM_COMPUTE_ERROR_ON(input->info()->dimension(2) != output->info()->dimension(2)); + ARM_COMPUTE_ERROR_ON((input->info()->dimension(2) * depth_multiplier) != output->info()->dimension(2)); ARM_COMPUTE_ERROR_ON(output->info()->dimension(0) != (kernel_dims.width * kernel_dims.height + ((has_bias) ? 1 : 0))); - _input = input; - _output = output; - _kernel_dims = kernel_dims; - _conv_info = conv_info; - _has_bias = has_bias; + _input = input; + _output = output; + _kernel_dims = kernel_dims; + _conv_info = conv_info; + _has_bias = has_bias; + _depth_multiplier = depth_multiplier; // Configure kernel window Window win = calculate_max_window(*input->info(), Steps()); -- cgit v1.2.1