diff options
Diffstat (limited to 'src/runtime/gpu/cl/operators/ClConv2d.cpp')
-rw-r--r-- | src/runtime/gpu/cl/operators/ClConv2d.cpp | 40 |
1 files changed, 35 insertions, 5 deletions
diff --git a/src/runtime/gpu/cl/operators/ClConv2d.cpp b/src/runtime/gpu/cl/operators/ClConv2d.cpp index 4cd65290f3..2f4d673d9c 100644 --- a/src/runtime/gpu/cl/operators/ClConv2d.cpp +++ b/src/runtime/gpu/cl/operators/ClConv2d.cpp @@ -36,6 +36,35 @@ #include <memory> +namespace +{ +/** Get the suitable kernel size for using direct convolution method with NHWC data layout. + * + * @note Direct convolution should be executed when the kernel has the spatial dimensions greater than or equal to the value returned by this function + * + * @param[in] gpu_target GPU target + * + * @return the suitable kernel size for using direct convolution method with NHWC data layout + */ +size_t get_direct_conv_kernel_threshold_nhwc(arm_compute::GPUTarget gpu_target) +{ + switch(gpu_target) + { + case arm_compute::GPUTarget::G76: + case arm_compute::GPUTarget::G77: + case arm_compute::GPUTarget::G78: + return 5; + case arm_compute::GPUTarget::G71: + case arm_compute::GPUTarget::G72: + case arm_compute::GPUTarget::MIDGARD: + case arm_compute::GPUTarget::BIFROST: + return 7; + default: + return 5; + } +} +} // namespace + namespace arm_compute { namespace opencl @@ -132,7 +161,6 @@ ConvolutionMethod ClConv2d::get_convolution_method(const ITensorInfo *src, const ARM_COMPUTE_ERROR_ON_NULLPTR(dst); ARM_COMPUTE_ERROR_ON_NULLPTR(weights); ARM_COMPUTE_UNUSED(weights_info); - ARM_COMPUTE_UNUSED(gpu_target); const PadStrideInfo conv_info = conv2d_info.conv_info; const ActivationLayerInfo act_info = conv2d_info.act_info; @@ -206,8 +234,9 @@ ConvolutionMethod ClConv2d::get_convolution_method(const ITensorInfo *src, const } else { - const bool is_direct_valid = bool(ClDirectConv2d::validate(src, weights, nullptr, dst, conv_info, act_info)); - const bool is_wino_valid = bool(ClWinogradConv2d::validate(src, weights, nullptr, dst, conv_info, act_info, enable_fast_math)); + const bool is_direct_valid = bool(ClDirectConv2d::validate(src, weights, nullptr, dst, conv_info, act_info)); + const bool is_wino_valid = bool(ClWinogradConv2d::validate(src, weights, nullptr, dst, conv_info, act_info, enable_fast_math)); + const size_t kernel_sz_direct_conv_thr = get_direct_conv_kernel_threshold_nhwc(gpu_target); // SRGAN case if((src->dimension(idx_h) > 720U) && (dst->dimension(idx_h) > 720U) && (weights->dimension(idx_h) == 9) && (conv_info.pad_top() < 3) @@ -219,8 +248,9 @@ ConvolutionMethod ClConv2d::get_convolution_method(const ITensorInfo *src, const // Floating-point case: GeMM/Direct/Winograd if(is_data_type_float(src->data_type())) { - const bool is_large_kernel_sz = (weights->dimension(idx_w) >= 7) && (weights->dimension(idx_h) >= 7); + const bool is_large_kernel_sz = (weights->dimension(idx_w) >= kernel_sz_direct_conv_thr) && (weights->dimension(idx_h) >= kernel_sz_direct_conv_thr); const bool is_ifm_ge_16 = src->dimension(idx_c) >= 16; + const bool is_ifm_gt_ofm = src->dimension(idx_c) > weights->dimension(3U); // Run Winograd if valid and IFM >= 16 if(is_wino_valid && is_ifm_ge_16) @@ -228,7 +258,7 @@ ConvolutionMethod ClConv2d::get_convolution_method(const ITensorInfo *src, const return ConvolutionMethod::WINOGRAD; } // Run Direct for Large kernel size - if(is_large_kernel_sz && is_ifm_ge_16 && is_direct_valid) + if(is_large_kernel_sz && is_ifm_ge_16 && is_direct_valid && is_ifm_gt_ofm) { return ConvolutionMethod::DIRECT; } |