From ad9a7ed2f9969381af0b9c97438a3402e16d9483 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 16 Sep 2022 14:14:21 +0100 Subject: Rework DepthwiseConvolution heuristic on OpenCL Resolves COMPMID-5632 Change-Id: I2bdbe69a610ca2510fbd74d5d412842679299762 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8365 Benchmark: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Viet-Hoa Do Reviewed-by: Jakub Sujak Comments-Addressed: Arm Jenkins --- arm_compute/core/CL/CLHelpers.h | 8 +-- arm_compute/core/KernelDescriptors.h | 1 + src/core/CL/CLHelpers.cpp | 2 +- src/core/CL/DefaultLWSHeuristics.cpp | 21 +++++- src/core/CL/cl_kernels/nhwc/dwc_native_fp_nhwc.cl | 6 +- .../CLDepthwiseConvolutionLayerNativeKernel.cpp | 75 +++++++++++++++------- .../CLDepthwiseConvolutionLayerNativeKernel.h | 5 +- src/gpu/cl/kernels/ClDirectConv2dKernel.cpp | 2 +- .../ClDirectConvDefaultConfigBifrost.cpp | 4 +- .../ClDirectConvDefaultConfigValhall.cpp | 32 ++++----- .../CL/functions/CLDepthwiseConvolutionLayer.cpp | 37 +++++++++-- .../fixtures/DepthwiseConvolutionLayerFixture.h | 1 + 12 files changed, 134 insertions(+), 60 deletions(-) diff --git a/arm_compute/core/CL/CLHelpers.h b/arm_compute/core/CL/CLHelpers.h index edbc705c6f..a162b1c118 100644 --- a/arm_compute/core/CL/CLHelpers.h +++ b/arm_compute/core/CL/CLHelpers.h @@ -242,13 +242,13 @@ bool get_wbsm_support_info(const cl::Device &device); */ void set_wbsm(cl::Kernel &kernel, cl_int wbsm_hint); -/* Helper function to check if we can export the weights to cl_image +/* Helper function to check if we can export the tensor to cl_image * - * @param[in] tensor Weights tensor + * @param[in] input tensor * - * @return true if we can export the weights to cl_image + * @return true if we can export the tensor to cl_image */ -bool export_weights_to_cl_image(const ITensorInfo *tensor); +bool export_to_cl_image(const ITensorInfo *tensor); /* Helper function to force unroll with pragma when any of the input values (iterations) are greater than @ref max_manual_loop_unrolling * diff --git a/arm_compute/core/KernelDescriptors.h b/arm_compute/core/KernelDescriptors.h index c45be9c06f..cacbef25ea 100644 --- a/arm_compute/core/KernelDescriptors.h +++ b/arm_compute/core/KernelDescriptors.h @@ -106,6 +106,7 @@ struct DWCComputeKernelInfo { unsigned int n0{ 0 }; /**< Number of columns processed by each thread */ unsigned int m0{ 0 }; /**< Number of rows processed by each thread */ + bool export_input_to_cl_image{ false }; /**< Export input to cl_image */ bool export_weights_to_cl_image{ false }; /**< Export the weights to cl_image */ }; diff --git a/src/core/CL/CLHelpers.cpp b/src/core/CL/CLHelpers.cpp index 94675d60cc..b31864211c 100644 --- a/src/core/CL/CLHelpers.cpp +++ b/src/core/CL/CLHelpers.cpp @@ -441,7 +441,7 @@ void set_wbsm(cl::Kernel &kernel, cl_int wbsm_hint) ARM_COMPUTE_ERROR_ON(err != CL_SUCCESS); } -bool export_weights_to_cl_image(const ITensorInfo *tensor) +bool export_to_cl_image(const ITensorInfo *tensor) { if(tensor->tensor_shape()[0] % 4) { diff --git a/src/core/CL/DefaultLWSHeuristics.cpp b/src/core/CL/DefaultLWSHeuristics.cpp index c082d7fbf9..c739b9dc03 100644 --- a/src/core/CL/DefaultLWSHeuristics.cpp +++ b/src/core/CL/DefaultLWSHeuristics.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -68,6 +68,21 @@ cl::NDRange get_direct_lws(size_t gws_x, size_t gws_y, size_t gws_z) return cl::NDRange(8, 4, 1); } } + +cl::NDRange get_dwc_lws(size_t gws_x, size_t gws_y, size_t gws_z) +{ + ARM_COMPUTE_UNUSED(gws_y); + ARM_COMPUTE_UNUSED(gws_z); + + if(gws_x < 32) + { + return cl::NDRange(gws_x, 4, 4); + } + else + { + return cl::NDRange(8, 4, 2); + } +} } // namespace namespace arm_compute @@ -92,6 +107,10 @@ cl::NDRange get_default_lws_for_type(CLKernelType kernel_type, cl::NDRange gws) { return get_winograd_lws(gws_x, gws_y, gws_z); } + case CLKernelType::DEPTHWISE: + { + return get_dwc_lws(gws_x, gws_y, gws_z); + } default: { return CLKernelLibrary::get().default_ndrange(); diff --git a/src/core/CL/cl_kernels/nhwc/dwc_native_fp_nhwc.cl b/src/core/CL/cl_kernels/nhwc/dwc_native_fp_nhwc.cl index 8b14b27643..8a8458798e 100644 --- a/src/core/CL/cl_kernels/nhwc/dwc_native_fp_nhwc.cl +++ b/src/core/CL/cl_kernels/nhwc/dwc_native_fp_nhwc.cl @@ -145,7 +145,7 @@ __kernel void dwc_native_fp_nhwc( }) // Load tile from the src tensor (TILE A) - T_LOAD_NHWC_WITH_DILATION(SRC_DATA_TYPE, 1, _IM0_A, _IN0_A, SRC_TENSOR_TYPE, src, bout, yi + yk * DILATION_Y, xi, (cout / DEPTH_MULTIPLIER), src_w, src_h, DILATION_X, 1, _IBOUNDARY_CHECK, a); + T_LOAD_NHWC_WITH_DILATION(SRC_DATA_TYPE, 1, _IM0_A, _IN0_A, SRC_TENSOR_TYPE, src, bout, yi + yk * DILATION_Y, xi, (cout / DEPTH_MULTIPLIER), SRC_WIDTH, SRC_HEIGHT, DILATION_X, 1, _IBOUNDARY_CHECK, a); TILE(WEI_DATA_TYPE, _IM0_B, _IN0_B, b); @@ -185,7 +185,7 @@ __kernel void dwc_native_fp_nhwc( { LOOP_UNROLLING(int, m0, 0, 1, M0, { - int xi_out = min(xo + M0 - 1 - m0, (int)(dst_w) - 1); + int xi_out = min(xo + M0 - 1 - m0, (int)(DST_WIDTH) - 1); VSTORE_PARTIAL(N0, PARTIAL_N0) (c[M0 - 1 - m0].v, 0, (__global DST_DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + cout * sizeof(DST_DATA_TYPE) + (uint)xi_out * dst_stride_y + (uint)yo * dst_stride_z + (uint)bout * dst_stride_w)); }) @@ -194,7 +194,7 @@ __kernel void dwc_native_fp_nhwc( { LOOP_UNROLLING(int, m0, 0, 1, M0, { - int xi_out = min(xo + M0 - 1 - m0, (int)(dst_w) - 1); + int xi_out = min(xo + M0 - 1 - m0, (int)(DST_WIDTH) - 1); VSTORE(N0) (c[M0 - 1 - m0].v, 0, (__global DST_DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + cout * sizeof(DST_DATA_TYPE) + (uint)xi_out * dst_stride_y + (uint)yo * dst_stride_z + (uint)bout * dst_stride_w)); }) diff --git a/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.cpp b/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.cpp index 277cba47a6..cded31936c 100644 --- a/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.cpp +++ b/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.cpp @@ -59,7 +59,8 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON(conv_info.pad_stride_info.stride().first > 1 && dwc_info.m0 != 1); ARM_COMPUTE_RETURN_ERROR_ON(conv_info.dilation.x() > 1 && dwc_info.m0 != 1); - ARM_COMPUTE_RETURN_ERROR_ON_MSG((dwc_info.export_weights_to_cl_image == true) && (export_weights_to_cl_image(weights) == false), "Export to cl_image not supported!"); + ARM_COMPUTE_RETURN_ERROR_ON((dwc_info.export_input_to_cl_image == true)); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((dwc_info.export_weights_to_cl_image == true) && (export_to_cl_image(weights) == false), "Weights cannot be exported to cl_image!"); ARM_COMPUTE_RETURN_ERROR_ON((dwc_info.export_weights_to_cl_image == true) && ((dwc_info.n0 % 4) != 0)); ARM_COMPUTE_RETURN_ERROR_ON(conv_info.pad_stride_info.stride().first < 1); ARM_COMPUTE_RETURN_ERROR_ON(conv_info.pad_stride_info.stride().second < 1); @@ -161,7 +162,8 @@ CLDepthwiseConvolutionLayerNativeKernel::CLDepthwiseConvolutionLayerNativeKernel _depth_multiplier(1), _output_multipliers(nullptr), _output_shifts(nullptr), - _export_to_cl_image(false), + _export_input_to_cl_image(false), + _export_weights_to_cl_image(false), _is_quantized(false) { _type = CLKernelType::DEPTHWISE; @@ -192,15 +194,16 @@ void CLDepthwiseConvolutionLayerNativeKernel::configure(const CLCompileContext & const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_depthwise_convolution_shape(*(input->info()), *(weights->info()), conv_info); auto_init_if_empty(*(output->info()), input->info()->clone()->set_tensor_shape(output_shape).set_quantization_info(output->info()->quantization_info())); - _input = input; - _output = output; - _weights = weights; - _biases = biases; - _depth_multiplier = conv_info.depth_multiplier; - _output_multipliers = output_multipliers; - _output_shifts = output_shifts; - _export_to_cl_image = dwc_info.export_weights_to_cl_image; - _is_quantized = is_data_type_quantized(input->info()->data_type()); + _input = input; + _output = output; + _weights = weights; + _biases = biases; + _depth_multiplier = conv_info.depth_multiplier; + _output_multipliers = output_multipliers; + _output_shifts = output_shifts; + _export_input_to_cl_image = dwc_info.export_input_to_cl_image; + _export_weights_to_cl_image = dwc_info.export_weights_to_cl_image; + _is_quantized = is_data_type_quantized(input->info()->data_type()); const unsigned int n0 = adjust_vec_size(dwc_info.n0, output->info()->dimension(0)); const unsigned int m0 = std::min(dwc_info.m0, (unsigned int)output->info()->dimension(1)); @@ -208,8 +211,13 @@ void CLDepthwiseConvolutionLayerNativeKernel::configure(const CLCompileContext & CLBuildOptions build_opts; - // Update the padding for the weights tensor if we can export to cl_image - if(_export_to_cl_image) + // Update the padding for the input/weights tensor if we can export to cl_image + if(_export_input_to_cl_image) + { + arm_compute::opencl::kernels::gemm::update_padding_for_cl_image(input->info()); + } + + if(_export_weights_to_cl_image) { arm_compute::opencl::kernels::gemm::update_padding_for_cl_image(weights->info()); } @@ -234,14 +242,18 @@ void CLDepthwiseConvolutionLayerNativeKernel::configure(const CLCompileContext & build_opts.add_option("-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(act_function))); build_opts.add_option("-DDEPTH_MULTIPLIER=" + support::cpp11::to_string(conv_info.depth_multiplier)); - build_opts.add_option("-DSRC_TENSOR_TYPE=BUFFER"); + build_opts.add_option_if_else(_export_input_to_cl_image, "-DSRC_TENSOR_TYPE=IMAGE", "-DSRC_TENSOR_TYPE=BUFFER"); // Note: SRC_DATA_TYPE must have the same data type of WEI_DATA_TYPE. In quantized, we could // have a case where the data types for the activation and weights are different. However, since the implementation // only works when both have same data type, we have to change the offset to take into account this aspect build_opts.add_option("-DSRC_DATA_TYPE=" + get_cl_type_from_data_type(_input->info()->data_type())); build_opts.add_option("-DDST_TENSOR_TYPE=BUFFER"); build_opts.add_option("-DDST_DATA_TYPE=" + get_cl_type_from_data_type(dst_data_type)); - build_opts.add_option_if_else(_export_to_cl_image, "-DWEI_TENSOR_TYPE=IMAGE", "-DWEI_TENSOR_TYPE=BUFFER"); + build_opts.add_option_if_else(_export_weights_to_cl_image, "-DWEI_TENSOR_TYPE=IMAGE", "-DWEI_TENSOR_TYPE=BUFFER"); + build_opts.add_option("-DSRC_WIDTH=" + support::cpp11::to_string(_input->info()->dimension(1))); + build_opts.add_option("-DSRC_HEIGHT=" + support::cpp11::to_string(_input->info()->dimension(2))); + build_opts.add_option("-DDST_WIDTH=" + support::cpp11::to_string(_output->info()->dimension(1))); + build_opts.add_option("-DDST_HEIGHT=" + support::cpp11::to_string(_output->info()->dimension(2))); build_opts.add_option("-DWEI_WIDTH=" + support::cpp11::to_string(_weights->info()->dimension(1))); build_opts.add_option("-DWEI_HEIGHT=" + support::cpp11::to_string(_weights->info()->dimension(2))); build_opts.add_option("-DWEI_DATA_TYPE=" + get_cl_type_from_data_type(_weights->info()->data_type())); @@ -353,24 +365,39 @@ void CLDepthwiseConvolutionLayerNativeKernel::run(const Window &window, cl::Comm Window slice = window_collapsed.first_slice_window_4D(); + cl::Image2D input_cl_image; cl::Image2D weights_cl_image; - if(_export_to_cl_image) + if(_export_input_to_cl_image || _export_weights_to_cl_image) { - const size_t image_w = _weights->info()->dimension(0) / 4; - const size_t image_h = _weights->info()->dimension(1) * _weights->info()->dimension(2) * _weights->info()->dimension(3); - const TensorShape shape2d(image_w, image_h); - const size_t image_row_pitch = _weights->info()->strides_in_bytes()[1]; - // Export cl_buffer to cl_image - weights_cl_image = create_image2d_from_buffer(CLKernelLibrary::get().context(), _weights->cl_buffer(), shape2d, _weights->info()->data_type(), image_row_pitch); + if(_export_input_to_cl_image) + { + const size_t image_w = _input->info()->dimension(0) / 4; + const size_t image_h = _input->info()->dimension(1) * _input->info()->dimension(2) * _input->info()->dimension(3); + const TensorShape shape2d(image_w, image_h); + const size_t image_row_pitch = _input->info()->strides_in_bytes()[1]; + input_cl_image = create_image2d_from_buffer(CLKernelLibrary::get().context(), _input->cl_buffer(), shape2d, _input->info()->data_type(), image_row_pitch); + } + + if(_export_weights_to_cl_image) + { + const size_t image_w = _weights->info()->dimension(0) / 4; + const size_t image_h = _weights->info()->dimension(1) * _weights->info()->dimension(2) * _weights->info()->dimension(3); + const TensorShape shape2d(image_w, image_h); + const size_t image_row_pitch = _weights->info()->strides_in_bytes()[1]; + weights_cl_image = create_image2d_from_buffer(CLKernelLibrary::get().context(), _weights->cl_buffer(), shape2d, _weights->info()->data_type(), image_row_pitch); + } } unsigned int idx = 0; + if(_export_input_to_cl_image) + { + _kernel.setArg(idx++, input_cl_image); + } add_4d_tensor_nhwc_argument(idx, _input); add_4d_tensor_nhwc_argument(idx, _output); - - if(_export_to_cl_image) + if(_export_weights_to_cl_image) { _kernel.setArg(idx++, weights_cl_image); } diff --git a/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.h b/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.h index eeed115832..5352f685ea 100644 --- a/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.h +++ b/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -103,7 +103,8 @@ private: unsigned int _depth_multiplier{ 0 }; const ICLTensor *_output_multipliers{}; const ICLTensor *_output_shifts{}; - bool _export_to_cl_image { true }; + bool _export_input_to_cl_image{ false }; + bool _export_weights_to_cl_image{ true }; bool _is_quantized{ false }; }; } // namespace arm_compute diff --git a/src/gpu/cl/kernels/ClDirectConv2dKernel.cpp b/src/gpu/cl/kernels/ClDirectConv2dKernel.cpp index 722c802138..fd14f009e1 100644 --- a/src/gpu/cl/kernels/ClDirectConv2dKernel.cpp +++ b/src/gpu/cl/kernels/ClDirectConv2dKernel.cpp @@ -94,7 +94,7 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, co { ARM_COMPUTE_RETURN_ERROR_ON_MSG(desc.k0 != 4 && desc.k0 != 8 && desc.k0 != 16, "K0 can only be: 4, 8, and 16"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!export_weights_to_cl_image(weights), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!export_to_cl_image(weights), "Export to CLImage is not supported for this weight configuration"); } } diff --git a/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.cpp b/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.cpp index 4ea198133b..ba176f8c5f 100644 --- a/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.cpp +++ b/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.cpp @@ -159,7 +159,7 @@ DirectConvComputeKernelInfo ClDirectConvDefaultConfigBifrost::configure_default_ desc.k0 = 8; - desc.export_weights_to_cl_image = export_weights_to_cl_image(wei); + desc.export_weights_to_cl_image = export_to_cl_image(wei); } return desc; @@ -183,7 +183,7 @@ DirectConvComputeKernelInfo ClDirectConvDefaultConfigBifrost::configure_default_ desc.k0 = 8; - desc.export_weights_to_cl_image = export_weights_to_cl_image(wei); + desc.export_weights_to_cl_image = export_to_cl_image(wei); } return desc; diff --git a/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.cpp b/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.cpp index d87cada159..ad94678335 100644 --- a/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.cpp +++ b/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.cpp @@ -77,15 +77,15 @@ DirectConvComputeKernelInfo ClDirectConvDefaultConfigValhall::configure_G78_f32( if(src->data_layout() == DataLayout::NHWC) { // Get the output shape - const TensorShape wei_shape = wei->tensor_shape(); - const TensorShape dst_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); - const bool export_to_cl_image = export_weights_to_cl_image(wei); + const TensorShape wei_shape = wei->tensor_shape(); + const TensorShape dst_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); + const bool export_weights_to_cl_image = export_to_cl_image(wei); const int32_t ofm = dst_shape[0]; const int32_t m = dst_shape[1] * dst_shape[2]; const bool is_pointwise = (wei_shape[1] == wei_shape[2]) && wei_shape[1] == 1; - desc.export_weights_to_cl_image = export_to_cl_image; + desc.export_weights_to_cl_image = export_weights_to_cl_image; if(dst_shape[0] <= 4) { @@ -138,15 +138,15 @@ DirectConvComputeKernelInfo ClDirectConvDefaultConfigValhall::configure_G78_f16( if(src->data_layout() == DataLayout::NHWC) { // Get the output shape - const TensorShape wei_shape = wei->tensor_shape(); - const TensorShape dst_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); - const bool export_to_cl_image = export_weights_to_cl_image(wei); + const TensorShape wei_shape = wei->tensor_shape(); + const TensorShape dst_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); + const bool export_weights_to_cl_image = export_to_cl_image(wei); const int32_t ofm = dst_shape[0]; const int32_t m = dst_shape[1] * dst_shape[2]; const bool is_pointwise = (wei_shape[1] == wei_shape[2]) && wei_shape[1] == 1; - desc.export_weights_to_cl_image = export_to_cl_image; + desc.export_weights_to_cl_image = export_weights_to_cl_image; if(dst_shape[0] <= 4) { @@ -232,14 +232,14 @@ DirectConvComputeKernelInfo ClDirectConvDefaultConfigValhall::configure_G57_f32( if(src->data_layout() == DataLayout::NHWC) { // Get the output shape - const TensorShape wei_shape = wei->tensor_shape(); - const TensorShape dst_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); - const bool export_to_cl_image = export_weights_to_cl_image(wei); + const TensorShape wei_shape = wei->tensor_shape(); + const TensorShape dst_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); + const bool export_weights_to_cl_image = export_to_cl_image(wei); const int32_t m = dst_shape[1] * dst_shape[2]; const bool is_pointwise = (wei_shape[1] == wei_shape[2]) && wei_shape[1] == 1; - desc.export_weights_to_cl_image = export_to_cl_image; + desc.export_weights_to_cl_image = export_weights_to_cl_image; if(dst_shape[0] <= 4) { @@ -292,15 +292,15 @@ DirectConvComputeKernelInfo ClDirectConvDefaultConfigValhall::configure_G57_f16( if(src->data_layout() == DataLayout::NHWC) { // Get the output shape - const TensorShape wei_shape = wei->tensor_shape(); - const TensorShape dst_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); - const bool export_to_cl_image = export_weights_to_cl_image(wei); + const TensorShape wei_shape = wei->tensor_shape(); + const TensorShape dst_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); + const bool export_weights_to_cl_image = export_to_cl_image(wei); const int32_t ofm = dst_shape[0]; const int32_t m = dst_shape[1] * dst_shape[2]; const bool is_pointwise = (wei_shape[1] == wei_shape[2]) && wei_shape[1] == 1; - desc.export_weights_to_cl_image = export_to_cl_image; + desc.export_weights_to_cl_image = export_weights_to_cl_image; if(dst_shape[0] <= 4) { diff --git a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp index 8546471fdd..3eadaee0de 100644 --- a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp @@ -44,7 +44,7 @@ namespace { bool export_weights_to_cl_image_heuristic(const ITensorInfo *weights, unsigned int depth_multiplier, GPUTarget gpu_target) { - if(!export_weights_to_cl_image(weights)) + if(!export_to_cl_image(weights)) { return false; } @@ -75,9 +75,12 @@ bool export_weights_to_cl_image_heuristic(const ITensorInfo *weights, unsigned i return true; } -void initialize_dwc_native_compute_info(DWCComputeKernelInfo &dwc_compute_info, const ITensorInfo *weights, const PadStrideInfo &conv_info, const Size2D &dilation, unsigned int depth_multiplier, +void initialize_dwc_native_compute_info(DWCComputeKernelInfo &dwc_compute_info, const ITensorInfo *input, const ITensorInfo *weights, const PadStrideInfo &conv_info, const Size2D &dilation, + unsigned int depth_multiplier, GPUTarget gpu_target) { + ARM_COMPUTE_UNUSED(input); + if(!is_data_type_float(weights->data_type())) { dwc_compute_info.export_weights_to_cl_image = false; @@ -97,6 +100,7 @@ void initialize_dwc_native_compute_info(DWCComputeKernelInfo &dwc_compute_info, // Floating point path // First check if we can export to cl_image. + dwc_compute_info.export_input_to_cl_image = false; dwc_compute_info.export_weights_to_cl_image = export_weights_to_cl_image_heuristic(weights, depth_multiplier, gpu_target); // Set n0 @@ -135,7 +139,28 @@ void initialize_dwc_native_compute_info(DWCComputeKernelInfo &dwc_compute_info, const size_t idx_w = get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::WIDTH); const size_t kernel_w = weights->tensor_shape()[idx_w]; - dwc_compute_info.m0 = (kernel_w >= 9) || (kernel_w == 1) ? 1 : 2; + if((kernel_w >= 9) || (kernel_w == 1)) + { + dwc_compute_info.m0 = 1; + } + else + { + if(weights->data_type() == DataType::F16) + { + if((input->dimension(1) % 5) == 0) + { + dwc_compute_info.m0 = 5; + } + else + { + dwc_compute_info.m0 = 4; + } + } + else + { + dwc_compute_info.m0 = 2; + } + } } else { @@ -237,7 +262,7 @@ void CLDepthwiseConvolutionLayer::configure(const CLCompileContext &compile_cont } DWCComputeKernelInfo dwc_native_compute_info; - initialize_dwc_native_compute_info(dwc_native_compute_info, weights_to_use->info(), conv_info, dilation, depth_multiplier, gpu_target); + initialize_dwc_native_compute_info(dwc_native_compute_info, input->info(), weights_to_use->info(), conv_info, dilation, depth_multiplier, gpu_target); const ConvolutionInfo conv_kernel_info{ conv_info, depth_multiplier, act_info, dilation }; @@ -322,7 +347,7 @@ Status CLDepthwiseConvolutionLayer::validate(const ITensorInfo *input, const ITe ARM_COMPUTE_RETURN_ON_ERROR(CLPermute::validate(weights, &permuted_weights, PermutationVector(2U, 0U, 1U))); DWCComputeKernelInfo dwc_native_compute_info; - initialize_dwc_native_compute_info(dwc_native_compute_info, &permuted_weights, conv_info, dilation, depth_multiplier, gpu_target); + initialize_dwc_native_compute_info(dwc_native_compute_info, input, &permuted_weights, conv_info, dilation, depth_multiplier, gpu_target); ARM_COMPUTE_RETURN_ON_ERROR(CLDepthwiseConvolutionLayerNativeKernel::validate(&permuted_input, &permuted_weights, biases, &permuted_output, dwc_native_compute_info, conv_kernel_info, &output_multipliers_shifts_info, &output_multipliers_shifts_info)); @@ -331,7 +356,7 @@ Status CLDepthwiseConvolutionLayer::validate(const ITensorInfo *input, const ITe else { DWCComputeKernelInfo dwc_native_compute_info; - initialize_dwc_native_compute_info(dwc_native_compute_info, weights, conv_info, dilation, depth_multiplier, gpu_target); + initialize_dwc_native_compute_info(dwc_native_compute_info, input, weights, conv_info, dilation, depth_multiplier, gpu_target); ARM_COMPUTE_RETURN_ON_ERROR(CLDepthwiseConvolutionLayerNativeKernel::validate(input, weights, biases, output, dwc_native_compute_info, conv_kernel_info, &output_multipliers_shifts_info, &output_multipliers_shifts_info)); } diff --git a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h index 9fd973ad20..58e5c528e7 100644 --- a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h +++ b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h @@ -482,6 +482,7 @@ public: DWCComputeKernelInfo dwc_info; dwc_info.n0 = _n0; dwc_info.m0 = _conv_info.stride().first == 1 && _dilation.x() == 1 ? 8 : 1; + dwc_info.export_input_to_cl_image = false; dwc_info.export_weights_to_cl_image = _export_to_cl_image; const ConvolutionInfo conv_kernel_info -- cgit v1.2.1