From c6aa49b6709edada24b1ab3bc1308e0974f9e057 Mon Sep 17 00:00:00 2001
From: Giorgio Arena
Date: Tue, 7 Aug 2018 11:53:30 +0100
Subject: COMPMID-1344 Add grouping support to CLWeightsReshapeKernel

Change-Id: Idde333308db71087ec234b3fd1eb4e36a44db46c
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/143049
Reviewed-by: Gian Marco Iodice
Tested-by: Jenkins
---
 src/core/CL/cl_kernels/convolution_layer.cl    | 44 ++++++++++++++++++--------
 src/core/CL/kernels/CLWeightsReshapeKernel.cpp | 18 +++++++----
 2 files changed, 42 insertions(+), 20 deletions(-)

(limited to 'src/core/CL')

diff --git a/src/core/CL/cl_kernels/convolution_layer.cl b/src/core/CL/cl_kernels/convolution_layer.cl
index 2b83e5adf1..9335b047fe 100644
--- a/src/core/CL/cl_kernels/convolution_layer.cl
+++ b/src/core/CL/cl_kernels/convolution_layer.cl
@@ -23,10 +23,11 @@
  */
 #include "helpers.h"
 
-#if defined(DATA_TYPE)
+#if defined(DATA_TYPE) && defined(NUM_GROUPS)
 /** This kernel reshapes the tensor's low three dimensions to single column
  *
  * @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
+ * @note The number of groups should be given as a preprocessor argument using -DNUM_GROUPS=number. e.g. -DNUM_GROUPS=2
  *
  * @param[in] src_ptr      Pointer to the source tensor. Supported data types: F16/F32
  * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
@@ -50,6 +51,7 @@
  * @param[in] height        The height of the input tensor
  * @param[in] depth         The depth of the input tensor
  * @param[in] total_filters Total number of filters. 4th dimension of the weights matrix
+ * @param[in] dst_stride_z  Stride of the destination tensor in Z dimension (in bytes)
  */
 __kernel void reshape_to_columns_nchw(
     TENSOR3D_DECLARATION(src),
@@ -57,7 +59,7 @@ __kernel void reshape_to_columns_nchw(
 #ifdef HAS_BIAS
     VECTOR_DECLARATION(bias),
 #endif /* HAS_BIAS */
-    uint width, uint height, uint depth, uint total_filters)
+    uint width, uint height, uint depth, uint total_filters, uint dst_stride_z)
 {
     Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
     bool is_last_thread = (get_global_id(0) == (get_global_size(0) - 1) && get_global_id(1) == (get_global_size(1) - 1) && get_global_id(2) == (get_global_size(2) - 1));
@@ -71,25 +73,39 @@ __kernel void reshape_to_columns_nchw(
 
     if(is_last_thread)
     {
-        for(uint i = 0; i < total_filters; ++i)
+        for(uint g = 0; g < NUM_GROUPS; ++g)
         {
-            *((__global DATA_TYPE *)tmp_dst_ptr) = *((__global DATA_TYPE *)tmp_src_ptr);
+            __global uchar *curr_group_dst = tmp_dst_ptr;
+
+            for(uint i = 0; i < total_filters / NUM_GROUPS; ++i)
+            {
+                *((__global DATA_TYPE *)curr_group_dst) = *((__global DATA_TYPE *)tmp_src_ptr);
 
 #ifdef HAS_BIAS
-            *((__global DATA_TYPE *)(tmp_dst_ptr + dst_stride_y)) = *((__global DATA_TYPE *)(tmp_bias_ptr));
-            tmp_bias_ptr += bias_stride_x;
+                *((__global DATA_TYPE *)(curr_group_dst + dst_stride_y)) = *((__global DATA_TYPE *)(tmp_bias_ptr));
+                tmp_bias_ptr += bias_stride_x;
 #endif /* HAS_BIAS */
-            tmp_src_ptr += depth * src_stride_z;
-            tmp_dst_ptr += dst_stride_x;
+                tmp_src_ptr += depth * src_stride_z;
+                curr_group_dst += dst_stride_x;
+            }
+
+            tmp_dst_ptr += dst_stride_z;
         }
     }
     else
    {
-        for(uint i = 0; i < total_filters; ++i)
+        for(uint g = 0; g < NUM_GROUPS; ++g)
         {
-            *((__global DATA_TYPE *)tmp_dst_ptr) = *((__global DATA_TYPE *)tmp_src_ptr);
-            tmp_src_ptr += depth * src_stride_z;
-            tmp_dst_ptr += dst_stride_x;
+            __global uchar *curr_group_dst = tmp_dst_ptr;
+
+            for(uint i = 0; i < total_filters / NUM_GROUPS; ++i)
+            {
+                *((__global DATA_TYPE *)curr_group_dst) = *((__global DATA_TYPE *)tmp_src_ptr);
+                tmp_src_ptr += depth * src_stride_z;
+                curr_group_dst += dst_stride_x;
+            }
+
+            tmp_dst_ptr += dst_stride_z;
         }
     }
 }
@@ -127,7 +143,7 @@ __kernel void reshape_to_columns_nhwc(
 #ifdef HAS_BIAS
     VECTOR_DECLARATION(bias),
 #endif /* HAS_BIAS */
-    uint depth, uint width, uint height, uint total_filters)
+    uint depth, uint width, uint height, uint total_filters, uint dst_stride_z)
 {
     Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
     bool is_last_thread = (get_global_id(0) == (get_global_size(0) - 1) && get_global_id(1) == (get_global_size(1) - 1) && get_global_id(2) == (get_global_size(2) - 1));
@@ -163,4 +179,4 @@
         }
     }
 }
-#endif // defined(DATA_TYPE)
\ No newline at end of file
+#endif // defined(DATA_TYPE) && defined(NUM_GROUPS)
\ No newline at end of file
diff --git a/src/core/CL/kernels/CLWeightsReshapeKernel.cpp b/src/core/CL/kernels/CLWeightsReshapeKernel.cpp
index 5243c4099e..9df91fccde 100644
--- a/src/core/CL/kernels/CLWeightsReshapeKernel.cpp
+++ b/src/core/CL/kernels/CLWeightsReshapeKernel.cpp
@@ -38,11 +38,15 @@ using namespace arm_compute::misc::shape_calculator;
 
 namespace
 {
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *biases, const ITensorInfo *output)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *biases, const ITensorInfo *output, const unsigned int num_groups)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
     ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON(num_groups == 0);
+    ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::NHWC && num_groups > 1);
+    ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4 && num_groups > 1);
+    ARM_COMPUTE_RETURN_ERROR_ON((input->dimension(3) % num_groups) != 0);
 
     if(biases != nullptr)
     {
@@ -57,7 +61,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *biases, c
     // Checks performed when output is configured
     if(output->total_size() != 0)
     {
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_weights_reshaped_shape(*input, biases != nullptr));
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_weights_reshaped_shape(*input, biases != nullptr, num_groups));
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
     }
@@ -71,7 +75,7 @@ CLWeightsReshapeKernel::CLWeightsReshapeKernel()
 {
 }
 
-void CLWeightsReshapeKernel::configure(const ICLTensor *input, const ICLTensor *biases, ICLTensor *output)
+void CLWeightsReshapeKernel::configure(const ICLTensor *input, const ICLTensor *biases, ICLTensor *output, const unsigned int num_groups)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
 
@@ -81,7 +85,7 @@ void CLWeightsReshapeKernel::configure(const ICLTensor *input, const ICLTensor *
 
     // Perform validation step
     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (biases != nullptr) ? biases->info() : nullptr,
-                                                  output->info()));
+                                                  output->info(), num_groups));
 
     const DataType data_type = input->info()->data_type();
     const DataLayout data_layout = input->info()->data_layout();
@@ -93,6 +97,7 @@ void CLWeightsReshapeKernel::configure(const ICLTensor *input, const ICLTensor *
     // Create build options
     CLBuildOptions build_opts;
     build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type));
+    build_opts.add_option("-DNUM_GROUPS=" + support::cpp11::to_string(num_groups));
     build_opts.add_option_if(biases != nullptr, "-DHAS_BIAS");
 
     // Create kernel
@@ -106,6 +111,7 @@ void CLWeightsReshapeKernel::configure(const ICLTensor *input, const ICLTensor *
     _kernel.setArg(idx++, _input->info()->dimension(1));
     _kernel.setArg(idx++, _input->info()->dimension(2));
     _kernel.setArg(idx++, _input->info()->dimension(3));
+    _kernel.setArg(idx++, _output->info()->strides_in_bytes().z());
 
     // Configure window
     Window win = calculate_max_window(*input->info(), Steps());
@@ -114,9 +120,9 @@
     ICLKernel::configure(win);
 }
 
-Status CLWeightsReshapeKernel::validate(const ITensorInfo *input, const ITensorInfo *biases, const ITensorInfo *output)
+Status CLWeightsReshapeKernel::validate(const ITensorInfo *input, const ITensorInfo *biases, const ITensorInfo *output, const unsigned int num_groups)
 {
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, biases, output));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, biases, output, num_groups));
     return Status{};
 }
-- 
cgit v1.2.1
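
For context on what the grouped kernel now produces: each group g writes its total_filters / NUM_GROUPS reshaped filters into its own Z plane of the destination, stepped by the new dst_stride_z argument, with one column per filter holding the flattened kernel_w * kernel_h * depth elements (plus one extra row for the bias when -DHAS_BIAS is set). The sketch below shows how the new four-argument validate() overload introduced here might be exercised at the kernel level. It is a minimal sketch, not part of the patch: the include paths, the 3x3x8x16 weights shape and the [filters-per-group, flattened-filter-length, num_groups] output shape are illustrative assumptions inferred from the grouped shape check in validate_arguments(), and may not match every library release.

// Hedged sketch (C++): validating a grouped weights reshape with the new num_groups argument.
// Header locations and the assumed reshaped output layout are illustrative, not taken from this patch.
#include "arm_compute/core/CL/kernels/CLWeightsReshapeKernel.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CL/CLScheduler.h"

#include <cstdio>

using namespace arm_compute;

int main()
{
    // CL kernels need an initialised CL backend before configure()/validate() can be used.
    CLScheduler::get().default_init();

    const unsigned int num_groups = 2;

    // NCHW weights: 3x3 kernels, 8 input channels per group, 16 output channels in total.
    // 16 % num_groups == 0, as required by the new checks in validate_arguments().
    const TensorInfo weights(TensorShape(3U, 3U, 8U, 16U), 1, DataType::F32);

    // Assumed reshaped layout: 16/2 = 8 columns per group, each 3*3*8 = 72 elements long
    // (one extra row would be added if a bias vector were passed), one Z plane per group.
    const TensorInfo reshaped(TensorShape(8U, 72U, num_groups), 1, DataType::F32);

    const Status status = CLWeightsReshapeKernel::validate(&weights, nullptr /* no bias */, &reshaped, num_groups);

    std::printf("grouped weights reshape %s: %s\n",
                status.error_code() == ErrorCode::OK ? "valid" : "rejected",
                status.error_description().c_str());
    return status.error_code() == ErrorCode::OK ? 0 : 1;
}

The host-side counterpart of the per-group plane is the extra kernel argument configured in the patch: configure() now passes _output->info()->strides_in_bytes().z() so the OpenCL kernel can step tmp_dst_ptr from one group's plane to the next.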